zoukankan      html  css  js  c++  java
  • LeNet-5以及tensorflow2.1实现

    LeNet-5

    LeNet-5网络结构


    其中池化层均采用最大池化,每一层卷积层后使用的激活函数是sigmoid函数。
    这里补充一下padding的两种方式,一个是SAME(全0填充),另一个是VALID(不填充)。在LeNet-5中,卷积层一致采用padding='SAME'的方式进行填充,在池化层中采用padding='VALID'的方式填充。
    填充的方式不同,那么输出的图片的边长也是不同的。

    padding方式 输出图片边长
    SAME (frac{输入长}{步长})(向上取整)
    VALID (frac{输入长-核长+1}{步长})(向上取整)

    tensorflow实现LeNet-5

    接下来,我们用tensorflow(2.1版本)来搭建LeNet-5,实现fashion_mnist的图片分类。

    import tensorflow as tf
    from tensorflow import keras
    
    # 搭建LeNet网络
    net = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters=6,kernel_size=5,activation='sigmoid',input_shape=(28,28,1)),
        tf.keras.layers.MaxPool2D(pool_size=2,strides=2),
        tf.keras.layers.Conv2D(filters=16,kernel_size=5,activation='sigmoid'),
        tf.keras.layers.MaxPool2D(pool_size=2,strides=2),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(120,activation='sigmoid'),
        tf.keras.layers.Dense(84,activation='sigmoid'),
        tf.keras.layers.Dense(10,activation='sigmoid')
    ])
    

    获取fashion_mnist的数据集

    fashion_mnist = tf.keras.datasets.fashion_mnist
    (train_images,train_labels),(test_images,test_labels) = fashion_mnist.load_data()
    
    train_images = tf.reshape(train_images,(train_images.shape[0],train_images.shape[1],train_images.shape[2],1))
    print(train_images.shape)
    test_images = tf.reshape(test_images,(test_images.shape[0],test_images.shape[1],test_images.shape[2],1))
    # 输出(60000,28,28,1)
    

    损失函数和训练算法采用交叉熵损失函数(cross entropy)和小批量随机梯度下降(SGD)

    optimizer = tf.keras.optimizers.SGD(learning_rate=0.9,momentum=0.0,nesterov=False)
    net.compile(optimizer=optimizer,
               loss='sparse_categorical_crossentropy',
               metrics=['accuracy'])
    
    net.fit(train_images,train_labels,epochs=5,validation_split=0.1)
    

    输出:

    net.evaluate(test_images,test_labels,verbose=2)
    

    输出:

  • 相关阅读:
    eclipse web项目转maven项目
    spark作业
    大数据学习——spark-steaming学习
    大数据学习——sparkSql对接hive
    大数据学习——sparkSql对接mysql
    大数据学习——sparkSql
    大数据学习——spark运营案例
    大数据学习——spark笔记
    大数据学习——sparkRDD
    python面试题
  • 原文地址:https://www.cnblogs.com/CuteyThyme/p/12741241.html
Copyright © 2011-2022 走看看