zoukankan      html  css  js  c++  java
  • TensorFlow学习笔记 使用CPABD实现最简单的CNN模型

    import os
    from tensorflow.keras.datasets import mnist
    import tensorflow as tf
    from tensorflow.python.keras import Model
    from tensorflow.python.keras.datasets import cifar10
    from tensorflow.python.keras.layers import Flatten, Dense, Conv2D, BatchNormalization, AvgPool2D, Activation, MaxPool2D, \
        Dropout
    
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    
    x_train, x_test = x_train/255.0, x_test/255.0
    
    checkpoint_save_path = './checkpoint/model.ckpt'
    
    
    # 搭建模型类, 口诀:CBAPD,C卷积B批标准化A激活P池化D全连接
    class ConvModel(Model):
        def __init__(self):
            super(ConvModel, self).__init__()
            # filters: 卷积核个数  kernel_size:卷积核尺寸  strides:横纵向步长  padding:是否使用全零填充,same为是 activation:激活函数
            self.conv1 = Conv2D(filters=6, kernel_size=(5, 5), strides=(1, 1), padding='same', activation=None)
            # 在激活函数前,先进行一次批标准化,使得输入值更靠近0均值
            self.bn = BatchNormalization()
            # 激活函数
            self.activation = Activation('relu')
            # 池化,减少输入特征值
            self.pool = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')
            # Dropout防止过拟合
            self.dropout1 = Dropout(0.2)
    
            # 特征抽取完,拉直维度后通过全连接层输出
            self.flatten = Flatten()
            self.d1 = Dense(128, activation='relu')
            self.dropout2 = Dropout(0.2)
            self.d2 = Dense(10, activation='softmax')
    
        def call(self, x):
            x = self.conv1(x)
            x = self.bn(x)
            x = self.activation(x)
            x = self.pool(x)
            x = self.dropout1(x)
            x = self.flatten(x)
            x = self.d1(x)
            x = self.dropout2(x)
            y = self.d2(x)
            return y
    
    
    model = ConvModel()
    
    # 模型优化
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['sparse_categorical_accuracy'])
    
    # callback保存模型
    model_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True,
                                                        save_best_only=True)
    
    # 曾经保存过,直接加载权重参数
    if os.path.exists(checkpoint_save_path + '.index'):
        model.load_weights(checkpoint_save_path)
    
    # 开始训练
    model.fit(x=x_train, y=y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), callbacks=[model_callback])
    
    # 结果总览
    model.summary()
  • 相关阅读:
    wget一个小技巧
    【iOS官方文档翻译】UICollectionView与UICollectionViewFlowLayout
    NSDate获取当前时区的时间
    怎样把一个字典的数据添加到另一个字典中?
    CoreLocation基本使用
    iOS开发--一步步教你彻底学会『iOS应用间相互跳转』
    Save Image to UserDefaults(用NSUserDefaults保存图片)
    SDWebImage源码解析
    获取cell或者cell中的控件在屏幕中的位置
    Git命令详解 123
  • 原文地址:https://www.cnblogs.com/yiduobaozhiblog1/p/15681807.html
Copyright © 2011-2022 走看看