zoukankan      html  css  js  c++  java
  • ResNet 实现 CIFAR-10 数据集分类

    import tensorflow as tf
    import os
    import numpy as np
    from matplotlib import pyplot as plt
    from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
    from tensorflow.keras import Model
    
    # Disable NumPy's summarisation ("...") so arrays converted with str() are
    # rendered in full.
    np.set_printoptions(threshold=np.inf)

    # Load the CIFAR-10 train/test splits and divide the image arrays by 255
    # (presumably mapping 8-bit pixel values into [0, 1]).
    cifar10 = tf.keras.datasets.cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    x_train = x_train / 255.0
    x_test = x_test / 255.0
    
    
    class ResnetBlock(Model):
        """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

        The block computes ``relu(F(x) + shortcut)``, where ``F`` is
        conv -> BN -> ReLU -> conv -> BN. The shortcut is the input itself, or,
        when ``residual_path`` is true, the input passed through a 1x1
        convolution (using the same ``strides``) followed by batch norm so its
        shape matches ``F(x)``.

        Args:
            filters: Number of output channels of every convolution in the block.
            strides: Stride of the first 3x3 convolution and of the 1x1 shortcut
                convolution.
            residual_path: Whether to project the shortcut with a 1x1 convolution.
        """

        def __init__(self, filters, strides=1, residual_path=False):
            super().__init__()
            self.filters = filters
            self.strides = strides
            self.residual_path = residual_path

            # Options shared by every convolution in this block.
            conv_opts = dict(padding='same', use_bias=False)

            # Main branch, first half: conv -> BN -> ReLU.
            self.c1 = Conv2D(filters, (3, 3), strides=strides, **conv_opts)
            self.b1 = BatchNormalization()
            self.a1 = Activation('relu')

            # Main branch, second half: conv -> BN (activation comes after the add).
            self.c2 = Conv2D(filters, (3, 3), strides=1, **conv_opts)
            self.b2 = BatchNormalization()

            # Optional projection shortcut: a 1x1 convolution that brings the
            # input to the same dimensions as F(x) so the two can be added.
            if residual_path:
                self.down_c1 = Conv2D(filters, (1, 1), strides=strides, **conv_opts)
                self.down_b1 = BatchNormalization()

            self.a2 = Activation('relu')

        def call(self, inputs):
            """Return ``relu(F(inputs) + shortcut(inputs))``."""
            # F(x): the main convolutional branch.
            main = self.a1(self.b1(self.c1(inputs)))
            main = self.b2(self.c2(main))

            # Shortcut: identity, or the 1x1 projection W*x when enabled.
            if self.residual_path:
                shortcut = self.down_b1(self.down_c1(inputs))
            else:
                shortcut = inputs

            # Sum of both branches, then the final activation.
            return self.a2(main + shortcut)
    
    
    class ResNet18(Model):
        """ResNet-style classifier assembled from ``ResnetBlock`` units.

        Layout: 3x3 conv stem (+ BN + ReLU) -> stacked residual stages ->
        global average pooling -> 10-way softmax dense layer with an L2
        kernel regulariser.

        Args:
            block_list: One entry per stage; each entry is the number of
                ``ResnetBlock`` instances to stack in that stage.
            initial_filters: Channel count of the stem and of the first stage;
                it is doubled after every stage.
        """

        def __init__(self, block_list, initial_filters=64):
            super().__init__()
            self.num_blocks = len(block_list)  # number of stages
            self.block_list = block_list
            self.out_filters = initial_filters

            # Stem.
            self.c1 = Conv2D(self.out_filters, (3, 3), strides=1, padding='same', use_bias=False)
            self.b1 = BatchNormalization()
            self.a1 = Activation('relu')

            # Residual stages.
            self.blocks = tf.keras.models.Sequential()
            for stage_idx, blocks_in_stage in enumerate(block_list):
                for block_idx in range(blocks_in_stage):
                    # The first block of every stage except the very first one
                    # downsamples (stride 2) and uses a projection shortcut.
                    downsample = stage_idx != 0 and block_idx == 0
                    self.blocks.add(
                        ResnetBlock(
                            self.out_filters,
                            strides=2 if downsample else 1,
                            residual_path=downsample,
                        )
                    )
                # The next stage uses twice as many filters as this one.
                self.out_filters *= 2

            # Classification head.
            self.p1 = tf.keras.layers.GlobalAveragePooling2D()
            self.f1 = tf.keras.layers.Dense(10, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())

        def call(self, inputs):
            """Run the stem, the residual stages and the head; return softmax outputs."""
            stem = self.a1(self.b1(self.c1(inputs)))
            features = self.blocks(stem)
            pooled = self.p1(features)
            return self.f1(pooled)
    
    
    # Four stages with two residual blocks each.
    model = ResNet18([2, 2, 2, 2])

    # from_logits=False: the loss is given probabilities, not raw logits.
    model.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=['sparse_categorical_accuracy'],
    )

    # Resume from a previously saved checkpoint when its index file is present.
    checkpoint_save_path = "./checkpoint/ResNet18.ckpt"
    checkpoint_index_path = checkpoint_save_path + '.index'
    if os.path.exists(checkpoint_index_path):
        print('-------------load the model-----------------')
        model.load_weights(checkpoint_save_path)

    # Save weights only, and only when the monitored quantity improves.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_save_path,
        save_weights_only=True,
        save_best_only=True,
    )

    # Train, evaluating on the test split after every epoch.
    history = model.fit(
        x_train,
        y_train,
        batch_size=1024,
        epochs=5,
        validation_data=(x_test, y_test),
        validation_freq=1,
        callbacks=[cp_callback],
    )
    model.summary()
    
    # print(model.trainable_variables)
    # Write the name, shape and values of every trainable variable to
    # ./weights.txt, one item per line. The newline literals are written as
    # '\n' escapes (a raw line break inside a quoted string is a syntax error),
    # and the `with` block closes the file even if a write raises.
    with open('./weights.txt', 'w') as file:
        for v in model.trainable_variables:
            file.write(str(v.name) + '\n')
            file.write(str(v.shape) + '\n')
            file.write(str(v.numpy()) + '\n')
    
    ###############################################    show   ###############################################
    
    # Plot the accuracy and loss curves of the training and validation sets.
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    # One row, two panels: accuracy on the left, loss on the right.
    # Each entry: (panel index, training curve, validation curve,
    #              training label, validation label, panel title).
    panels = (
        (1, acc, val_acc, 'Training Accuracy', 'Validation Accuracy',
         'Training and Validation Accuracy'),
        (2, loss, val_loss, 'Training Loss', 'Validation Loss',
         'Training and Validation Loss'),
    )
    for panel_idx, train_curve, val_curve, train_label, val_label, panel_title in panels:
        plt.subplot(1, 2, panel_idx)
        plt.plot(train_curve, label=train_label)
        plt.plot(val_curve, label=val_label)
        plt.title(panel_title)
        plt.legend()
    plt.show()

    注意理解 ResNet 网络结构。

  • 相关阅读:
    进程间通信的方式——信号、管道、消息队列、共享内存
    exit()与_exit()的区别(转)
    [Google Codejam] Round 1A 2016
    使用shell脚本自定义实现选择登录ssh
    PHP的反射机制【转载】
    PHP set_error_handler()函数的使用【转载】
    PHP错误异常处理详解【转载】
    php的memcache和memcached扩展区别【转载】
    .htaccess重写URL讲解
    PHP数据库扩展mysqli的函数试题
  • 原文地址:https://www.cnblogs.com/python2/p/13599195.html
Copyright © 2011-2022 走看看