zoukankan      html  css  js  c++  java
  • 使用InceptionNet对CIFAR-10数据集进行分类

    import tensorflow as tf
    import os
    import numpy as np
    from matplotlib import pyplot as plt
    # Parenthesized so the multi-line import is valid syntax without a trailing backslash.
    from tensorflow.keras.layers import (Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense,
                                         GlobalAveragePooling2D)
    from tensorflow.keras import Model

    # Print arrays in full (no "..." elision); the weight dump below writes v.numpy() via str().
    np.set_printoptions(threshold=np.inf)

    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    # Scale pixel values to [0, 1]. Casting to float32 *before* dividing avoids building
    # float64 copies of the whole dataset (twice the memory); the model computes in
    # float32 (Keras' default float dtype) regardless.
    x_train = x_train.astype(np.float32) / 255.0
    x_test = x_test.astype(np.float32) / 255.0
    
    
    class ConvBNRelu(Model):
        """Conv2D -> BatchNormalization -> ReLU building block.

        Args:
            ch: number of output filters of the convolution.
            kernelsz: convolution kernel size.
            strides: convolution stride.
            padding: convolution padding mode ('same' by default).
        """

        def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
            super(ConvBNRelu, self).__init__()
            self.model = tf.keras.models.Sequential([
                Conv2D(ch, kernelsz, strides=strides, padding=padding),
                BatchNormalization(),
                Activation('relu')
            ])

        def call(self, x, training=None):
            """Apply conv + BN + ReLU to ``x``.

            Args:
                x: input feature map.
                training: forwarded to BatchNormalization. Leave it as ``None`` so Keras
                    fills it in from the calling context: True inside ``fit`` (BN uses
                    batch statistics and updates its moving mean/variance), False inside
                    ``evaluate``/``predict`` (BN uses the learned moving statistics).
            """
            # Do not hard-code training=False here: BN only updates its moving statistics
            # when training=True, so forcing inference mode during fit would leave them at
            # their initial values (mean 0, variance 1) and BN would never normalize.
            return self.model(x, training=training)
    
    
    class InceptionBlk(Model):
        """One Inception module: four parallel branches concatenated along channels.

        Branches (every ConvBNRelu has ``ch`` output filters):
          1. 1x1 conv
          2. 1x1 conv -> 3x3 conv
          3. 1x1 conv -> 5x5 conv
          4. 3x3 max-pool (stride 1, 'same') -> 1x1 conv
        ``strides`` is applied by the first strided op of every branch and all ops use
        'same' padding, so the four outputs share one spatial size; the result has
        ``4 * ch`` channels.
        """

        def __init__(self, ch, strides=1):
            super(InceptionBlk, self).__init__()
            self.ch, self.strides = ch, strides
            # Attribute names (c1, c2_1, ...) are part of the TF checkpoint's object graph,
            # so they are kept stable.
            # Branch 1: single 1x1 conv.
            self.c1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
            # Branch 2: 1x1 conv then 3x3 conv.
            self.c2_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
            self.c2_2 = ConvBNRelu(ch, kernelsz=3, strides=1)
            # Branch 3: 1x1 conv then 5x5 conv.
            self.c3_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
            self.c3_2 = ConvBNRelu(ch, kernelsz=5, strides=1)
            # Branch 4: 3x3 max-pool then 1x1 conv (which carries the stride).
            self.p4_1 = MaxPool2D(3, strides=1, padding='same')
            self.c4_2 = ConvBNRelu(ch, kernelsz=1, strides=strides)

        def call(self, x):
            branch_outputs = [
                self.c1(x),
                self.c2_2(self.c2_1(x)),
                self.c3_2(self.c3_1(x)),
                self.c4_2(self.p4_1(x)),
            ]
            # axis=3 is the channel axis for channels_last (NHWC) tensors, Keras' default.
            return tf.concat(branch_outputs, axis=3)
    
    
    class Inception10(Model):
        """Small InceptionNet for image classification.

        Layout: a ConvBNRelu stem, then ``num_blocks`` stages of two InceptionBlks,
        global average pooling, and a softmax Dense layer. Each stage opens with a
        stride-2 InceptionBlk (downsampling) followed by a stride-1 one; the per-branch
        filter count doubles after every stage.

        Args:
            num_blocks: number of stages.
            num_classes: size of the softmax output.
            init_ch: filters of the stem and of the first stage's branches.
            **kwargs: forwarded to ``tf.keras.Model``.
        """

        def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
            super(Inception10, self).__init__(**kwargs)
            self.in_channels = init_ch  # not read anywhere in this script
            self.out_channels = init_ch
            self.num_blocks = num_blocks
            self.init_ch = init_ch
            self.c1 = ConvBNRelu(init_ch)
            self.blocks = tf.keras.models.Sequential()
            for _ in range(num_blocks):
                # First block of the stage downsamples, the second keeps resolution.
                for stage_stride in (2, 1):
                    self.blocks.add(InceptionBlk(self.out_channels, strides=stage_stride))
                # Widen the next stage.
                self.out_channels *= 2
            self.p1 = GlobalAveragePooling2D()
            self.f1 = Dense(num_classes, activation='softmax')

        def call(self, x):
            features = self.blocks(self.c1(x))
            return self.f1(self.p1(features))
    
    
    model = Inception10(num_blocks=2, num_classes=10)

    # The final Dense layer already applies softmax, so the loss receives probabilities.
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['sparse_categorical_accuracy'],
    )

    # A path without a '.h5' suffix selects TF-format weights, which are written as
    # several files; '<prefix>.index' exists once a checkpoint has been saved.
    # NOTE(review): this relies on tf.keras 2.x; Keras 3 requires '.weights.h5' paths.
    checkpoint_save_path = "./checkpoint/Inception10.ckpt"
    if os.path.exists(checkpoint_save_path + '.index'):
        print('-------------load the model-----------------')
        model.load_weights(checkpoint_save_path)

    # Save weights only, and only when the monitored metric (val_loss by default) improves.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_save_path,
        save_weights_only=True,
        save_best_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=1024,
        epochs=5,
        validation_data=(x_test, y_test),
        validation_freq=1,
        callbacks=[cp_callback],
    )
    # Called after fit: a subclassed model has no shapes until it has seen data.
    model.summary()
    
    # Dump each trainable variable's name, shape and values to a text file for inspection.
    # The `with` block guarantees the file is closed even if a write fails.
    with open('./weights.txt', 'w', encoding='utf-8') as f:
        for v in model.trainable_variables:
            f.write(str(v.name) + '\n')
            f.write(str(v.shape) + '\n')
            f.write(str(v.numpy()) + '\n')
    
    # -------------------------------------- show --------------------------------------

    # Plot training vs. validation accuracy (left) and loss (right) per epoch.
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    # (training curve, validation curve, training label, validation label, title)
    panels = (
        (acc, val_acc, 'Training Accuracy', 'Validation Accuracy', 'Training and Validation Accuracy'),
        (loss, val_loss, 'Training Loss', 'Validation Loss', 'Training and Validation Loss'),
    )
    for panel_idx, (train_curve, val_curve, train_label, val_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, panel_idx)
        plt.plot(train_curve, label=train_label)
        plt.plot(val_curve, label=val_label)
        plt.title(title)
        plt.legend()
    plt.show()

    注意体会InceptionNet与其他网络在设计思想上的异同。

  • 相关阅读:
    用leangoo看板工具实施多团队大规模敏捷开发
    单团队的Scrum敏捷开发-leangoo
    放弃在每日站会上按成员逐个发言
    Leangoo思维导图做OKR目标管理
    好用的思维导图软件(程序员必备)
    好用免费的思维导图工具,多人协作共同编辑
    什么是Scrum燃尽图?
    Scrum中文网-团队需要Scrum Master做这六件事
    项目管理工具Leangoo自定义字段的应用
    实施敏捷开发中,选择看板管理工具的几个要点
  • 原文地址:https://www.cnblogs.com/python2/p/13599044.html
Copyright © 2011-2022 走看看