zoukankan      html  css  js  c++  java
  • InceptionNet实现cifar10数据集

    import tensorflow as tf
    import os
    import numpy as np
    from matplotlib import pyplot as plt
    from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense, 
        GlobalAveragePooling2D
    from tensorflow.keras import Model
    
    np.set_printoptions(threshold=np.inf)
    
    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    
    
    class ConvBNRelu(Model):
        def __init__(self, ch, kernelsz=3, strides=1, padding='same'):
            super(ConvBNRelu, self).__init__()
            self.model = tf.keras.models.Sequential([
                Conv2D(ch, kernelsz, strides=strides, padding=padding),
                BatchNormalization(),
                Activation('relu')
            ])
    
        def call(self, x):
            x = self.model(x, training=False) #在training=False时,BN通过整个训练集计算均值、方差去做批归一化,training=True时,通过当前batch的均值、方差去做批归一化。推理时 training=False效果好
            return x
    
    
    class InceptionBlk(Model):
        def __init__(self, ch, strides=1):
            super(InceptionBlk, self).__init__()
            self.ch = ch
            self.strides = strides
            self.c1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
            self.c2_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
            self.c2_2 = ConvBNRelu(ch, kernelsz=3, strides=1)
            self.c3_1 = ConvBNRelu(ch, kernelsz=1, strides=strides)
            self.c3_2 = ConvBNRelu(ch, kernelsz=5, strides=1)
            self.p4_1 = MaxPool2D(3, strides=1, padding='same')
            self.c4_2 = ConvBNRelu(ch, kernelsz=1, strides=strides)
    
        def call(self, x):
            x1 = self.c1(x)
            x2_1 = self.c2_1(x)
            x2_2 = self.c2_2(x2_1)
            x3_1 = self.c3_1(x)
            x3_2 = self.c3_2(x3_1)
            x4_1 = self.p4_1(x)
            x4_2 = self.c4_2(x4_1)
            # concat along axis=channel
            x = tf.concat([x1, x2_2, x3_2, x4_2], axis=3)
            return x
    
    
    class Inception10(Model):
        def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
            super(Inception10, self).__init__(**kwargs)
            self.in_channels = init_ch
            self.out_channels = init_ch
            self.num_blocks = num_blocks
            self.init_ch = init_ch
            self.c1 = ConvBNRelu(init_ch)
            self.blocks = tf.keras.models.Sequential()
            for block_id in range(num_blocks):
                for layer_id in range(2):
                    if layer_id == 0:
                        block = InceptionBlk(self.out_channels, strides=2)
                    else:
                        block = InceptionBlk(self.out_channels, strides=1)
                    self.blocks.add(block)
                # enlarger out_channels per block
                self.out_channels *= 2
            self.p1 = GlobalAveragePooling2D()
            self.f1 = Dense(num_classes, activation='softmax')
    
        def call(self, x):
            x = self.c1(x)
            x = self.blocks(x)
            x = self.p1(x)
            y = self.f1(x)
            return y
    
    
    model = Inception10(num_blocks=2, num_classes=10)
    
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy'])
    
    checkpoint_save_path = "./checkpoint/Inception10.ckpt"
    if os.path.exists(checkpoint_save_path + '.index'):
        print('-------------load the model-----------------')
        model.load_weights(checkpoint_save_path)
    
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                     save_weights_only=True,
                                                     save_best_only=True)
    
    history = model.fit(x_train, y_train, batch_size=1024, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                        callbacks=[cp_callback])
    model.summary()
    
    # print(model.trainable_variables)
    file = open('./weights.txt', 'w')
    for v in model.trainable_variables:
        file.write(str(v.name) + '
    ')
        file.write(str(v.shape) + '
    ')
        file.write(str(v.numpy()) + '
    ')
    file.close()
    
    ###############################################    show   ###############################################
    
    # 显示训练集和验证集的acc和loss曲线
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.show()

    注意InceptionNet网络和其他网络思想

  • 相关阅读:
    mysql 远程登录修改配置
    scrapy--分布式爬虫
    win10---cmd终端下连接ubantu--SSH SERVER服务
    将python环境打包成.txt文件
    ubantu安装python3虚拟环境
    selenium 自动化安装火狐谷歌插件
    mysql主从复制-读写分离-原理
    mysql主从复制原理
    mysql储存引擎
    mysql检查-优化-分析
  • 原文地址:https://www.cnblogs.com/python2/p/13599044.html
Copyright © 2011-2022 走看看