zoukankan      html  css  js  c++  java
  • 用ResNet实现CIFAR-10数据集分类

    import tensorflow as tf
    import os
    import numpy as np
    from matplotlib import pyplot as plt
    from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
    from tensorflow.keras import Model
    
    # Never abbreviate printed arrays with "...": the weight dump further down
    # writes str(v.numpy()) and should contain every value.
    np.set_printoptions(threshold=np.inf)

    # Fetch CIFAR-10 through Keras, then scale pixel intensities by 1/255 so the
    # network sees values in [0, 1] instead of raw 0-255 integers.
    cifar10 = tf.keras.datasets.cifar10
    train_split, test_split = cifar10.load_data()
    (x_train, y_train), (x_test, y_test) = train_split, test_split
    del train_split, test_split
    x_train = x_train / 255.0
    x_test = x_test / 255.0
    
    
    class ResnetBlock(Model):
        """Basic two-convolution residual block computing relu(F(x) + shortcut(x)).

        The main branch F is conv3x3(stride=strides) -> BN -> ReLU -> conv3x3 -> BN.
        The shortcut is the input itself, or, when ``residual_path`` is True, a
        1x1 conv (same stride) followed by BN. That projection makes the
        shortcut's shape match F(x) so the two can be added.

        Args:
            filters: output channel count of both 3x3 convolutions (and of the
                1x1 projection, if present).
            strides: stride of the first 3x3 convolution and of the projection.
            residual_path: when True, use the 1x1 projection shortcut instead of
                the identity. Needed whenever strides != 1 or the channel count
                changes.
        """

        def __init__(self, filters, strides=1, residual_path=False):
            super().__init__()
            self.filters = filters
            self.strides = strides
            self.residual_path = residual_path

            # Main branch, first half: (possibly strided) 3x3 conv -> BN -> ReLU.
            # No conv bias because the following BN has its own offset.
            self.c1 = Conv2D(filters, (3, 3), strides=strides, padding='same', use_bias=False)
            self.b1 = BatchNormalization()
            self.a1 = Activation('relu')

            # Main branch, second half: stride-1 3x3 conv -> BN. The ReLU is
            # applied only after the shortcut has been added.
            self.c2 = Conv2D(filters, (3, 3), strides=1, padding='same', use_bias=False)
            self.b2 = BatchNormalization()

            # Optional projection shortcut (Wx) so x can be added to F(x).
            if residual_path:
                self.down_c1 = Conv2D(filters, (1, 1), strides=strides, padding='same', use_bias=False)
                self.down_b1 = BatchNormalization()

            self.a2 = Activation('relu')

        def call(self, inputs):
            # Main branch F(x).
            hidden = self.a1(self.b1(self.c1(inputs)))
            transformed = self.b2(self.c2(hidden))

            # Shortcut branch: identity x, or the projection Wx.
            shortcut = (self.down_b1(self.down_c1(inputs))
                        if self.residual_path else inputs)

            # Output is relu(F(x) + x) or relu(F(x) + Wx).
            return self.a2(transformed + shortcut)
    
    
    class ResNet18(Model):
        """ResNet image classifier assembled from stages of ResnetBlocks.

        Architecture: stem (3x3 conv -> BN -> ReLU), then the residual stages,
        then global average pooling, then a softmax Dense head with an L2 kernel
        regularizer.

        Args:
            block_list: number of ResnetBlocks in each stage (NOT the number of
                conv layers; each ResnetBlock holds two 3x3 convs). For example,
                [2, 2, 2, 2] gives the ResNet-18 layout. Every stage after the
                first opens with a stride-2 block that uses a projection
                shortcut, and each stage has twice the filters of the previous
                one.
            initial_filters: filter count of the stem and of the first stage.
            num_classes: number of output units of the softmax head. Defaults
                to 10 (CIFAR-10).
        """

        def __init__(self, block_list, initial_filters=64, num_classes=10):
            super(ResNet18, self).__init__()
            self.num_blocks = len(block_list)  # number of stages
            self.block_list = block_list
            self.out_filters = initial_filters

            # Stem: no downsampling; its width matches the first stage so that
            # stage's identity shortcuts line up.
            self.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias=False)
            self.b1 = BatchNormalization()
            self.a1 = Activation('relu')

            # All residual blocks are chained inside a single Sequential.
            self.blocks = tf.keras.models.Sequential()
            for stage_id, blocks_in_stage in enumerate(block_list):
                for layer_id in range(blocks_in_stage):
                    if stage_id != 0 and layer_id == 0:
                        # First block of a later stage: stride-2 downsampling
                        # plus a 1x1 projection shortcut so shapes still match.
                        block = ResnetBlock(self.out_filters, strides=2, residual_path=True)
                    else:
                        block = ResnetBlock(self.out_filters, residual_path=False)
                    self.blocks.add(block)
                self.out_filters *= 2  # the next stage has twice as many filters

            self.p1 = tf.keras.layers.GlobalAveragePooling2D()
            self.f1 = tf.keras.layers.Dense(num_classes, activation='softmax',
                                            kernel_regularizer=tf.keras.regularizers.l2())

        def call(self, inputs):
            # Stem -> residual stages -> global average pool -> class probabilities.
            x = self.c1(inputs)
            x = self.b1(x)
            x = self.a1(x)
            x = self.blocks(x)
            x = self.p1(x)
            y = self.f1(x)
            return y
    
    
    # ResNet-18 layout: four stages of two residual blocks each.
    model = ResNet18([2, 2, 2, 2])

    # The model's final layer already applies softmax, so the loss receives
    # probabilities, not logits (from_logits=False).
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                  metrics=['sparse_categorical_accuracy'])

    # Resume from a previous run if a TensorFlow-format checkpoint exists; such
    # a checkpoint is recognised by its '.index' file.
    # NOTE(review): Keras 3 (TF >= 2.16) requires weights-only checkpoint paths
    # to end in '.weights.h5'; this path assumes tf.keras 2.x -- confirm.
    checkpoint_save_path = "./checkpoint/ResNet18.ckpt"
    if os.path.exists(checkpoint_save_path + '.index'):
        print('-------------load the model-----------------')
        model.load_weights(checkpoint_save_path)

    # Save only the weights, and only when the monitored metric (ModelCheckpoint
    # defaults to val_loss) improves.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                     save_weights_only=True,
                                                     save_best_only=True)

    history = model.fit(x_train, y_train, batch_size=1024, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                        callbacks=[cp_callback])
    model.summary()

    # Dump the name, shape and full values of every trainable variable. The
    # context manager closes the file even if a write fails.
    with open('./weights.txt', 'w', encoding='utf-8') as file:
        for v in model.trainable_variables:
            file.write(str(v.name) + '\n')
            file.write(str(v.shape) + '\n')
            file.write(str(v.numpy()) + '\n')
    
    ###############################################    show   ###############################################

    # Plot training vs. validation curves: accuracy on the left, loss on the right.
    acc, val_acc = history.history['sparse_categorical_accuracy'], history.history['val_sparse_categorical_accuracy']
    loss, val_loss = history.history['loss'], history.history['val_loss']

    # One entry per subplot: (train curve, val curve, train label, val label, title).
    for panel_index, (train_curve, val_curve, train_label, val_label, panel_title) in enumerate((
            (acc, val_acc, 'Training Accuracy', 'Validation Accuracy', 'Training and Validation Accuracy'),
            (loss, val_loss, 'Training Loss', 'Validation Loss', 'Training and Validation Loss'),
    ), start=1):
        plt.subplot(1, 2, panel_index)
        plt.plot(train_curve, label=train_label)
        plt.plot(val_curve, label=val_label)
        plt.title(panel_title)
        plt.legend()
    plt.show()

    注意理解ResNet的网络结构

  • 相关阅读:
    JS面向(基于)对象编程--构造方法(函数)
    一个超简单的马里奥游戏
    JavaScript基于对象编程
    JavaScript基础之函数与数组
    Learn CSS
    The first day of HTML
    mysql cmd 无法登录
    datagrid 扩展 页脚 合计功能
    this高级应用
    (xxx.55).toFixed(1) 无法正确进位处理
  • 原文地址:https://www.cnblogs.com/python2/p/13599195.html
Copyright © 2011-2022 走看看