zoukankan      html  css  js  c++  java
  • tensorflow 2.0 学习 (十一)卷积神经网络 (一) MNIST数据集训练与预测 LeNet-5网络

    网络结构如下:

     代码如下:

     1 # encoding: utf-8
     2 
     3 import tensorflow as tf
     4 from tensorflow import keras
     5 from tensorflow.keras import layers, Sequential, losses, optimizers, datasets
     6 import matplotlib.pyplot as plt
     7 
     8 Epoch = 30
     9 path = r'G:2019pythonmnist.npz'
    10 (x, y), (x_val, y_val) = tf.keras.datasets.mnist.load_data(path)  # 60000 and 10000
    11 print('datasets:', x.shape, y.shape, x.min(), x.max())
    12 
    13 x = tf.convert_to_tensor(x, dtype = tf.float32)  #/255.    #0:1  ;   -1:1(不适合训练,准确度不高)
    14 # x = tf.reshape(x, [-1, 28*28])
    15 y = tf.convert_to_tensor(y, dtype=tf.int32)
    16 # y = tf.one_hot(y, depth=10)
    17 #将60000组训练数据切分为600组,每组100个数据
    18 train_db = tf.data.Dataset.from_tensor_slices((x, y))
    19 train_db = train_db.shuffle(60000)      #尽量与样本空间一样大
    20 train_db = train_db.batch(100)          #128
    21 
    22 x_val = tf.cast(x_val, dtype=tf.float32)
    23 y_val = tf.cast(y_val, dtype=tf.int32)
    24 test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    25 test_db = test_db.shuffle(10000)
    26 test_db = test_db.batch(100)        #128
    27 
    28 network = Sequential([
    29     layers.Conv2D(6, kernel_size=3, strides=1),  # 6个卷积核
    30     layers.MaxPooling2D(pool_size=2, strides=2),  # 池化层,高宽各减半
    31     layers.ReLU(),
    32     layers.Conv2D(16, kernel_size=3, strides=1),  # 16个卷积核
    33     layers.MaxPooling2D(pool_size=2, strides=2),  # 池化层,高宽各减半
    34     layers.ReLU(),
    35     layers.Flatten(),
    36 
    37     layers.Dense(120, activation='relu'),
    38     layers.Dense(84, activation='relu'),
    39     layers.Dense(10)
    40 ])
    41 network.build(input_shape=(4, 28, 28, 1))
    42 network.summary()
    43 optimizer = tf.keras.optimizers.RMSprop(0.001)              # 创建优化器,指定学习率
    44 criteon = losses.CategoricalCrossentropy(from_logits=True)
    45 
    46 # 保存训练和测试过程中的误差情况
    47 train_tot_loss = []
    48 test_tot_loss = []
    49 
    50 
    51 for step in range(Epoch):
    52     cor, tot = 0, 0
    53     for x, y in train_db:
    54         with tf.GradientTape() as tape:  # 构建梯度环境
    55             # 插入通道维度 [None,28,28] -> [None,28,28,1]
    56             x = tf.expand_dims(x, axis=3)
    57             out = network(x)
    58             y_true = tf.one_hot(y, 10)
    59             loss =criteon(y_true, out)
    60 
    61             out_train = tf.argmax(out, axis=-1)
    62             y_train = tf.cast(y, tf.int64)
    63             cor += float(tf.reduce_sum(tf.cast(tf.equal(y_train, out_train), dtype=tf.float32)))
    64             tot += x.shape[0]
    65 
    66             grads = tape.gradient(loss, network.trainable_variables)
    67             optimizer.apply_gradients(zip(grads, network.trainable_variables))
    68     print('After %d Epoch' % step)
    69     print('training acc is ', cor/tot)
    70     train_tot_loss.append(cor/tot)
    71 
    72     correct, total = 0, 0
    73     for x, y in test_db:
    74         x = tf.expand_dims(x, axis=3)
    75         out = network(x)
    76         pred = tf.argmax(out, axis=-1)
    77         y = tf.cast(y, tf.int64)
    78         correct += float(tf.reduce_sum(tf.cast(tf.equal(y, pred), dtype=tf.float32)))
    79         total += x.shape[0]
    80     print('testing acc is : ', correct/total)
    81     test_tot_loss.append(correct/total)
    82 
    83 
    84 plt.figure()
    85 plt.plot(train_tot_loss, 'b', label='train')
    86 plt.plot(test_tot_loss, 'r', label='test')
    87 plt.xlabel('Epoch')
    88 plt.ylabel('ACC')
    89 plt.legend()
    90 plt.savefig('exam8.2_train_test_CNN1.png')
    91 plt.show()

    训练和测试结果如下:

    下次更新CIFAR10数据集与改进VGG13网络

  • 相关阅读:
    PAT 1012 数字分类
    PAT 1046 划拳
    PAT 1021 个位数统计
    PAT 1003 我要通过!
    PAT 1031 查验身份证
    安装swoole
    PHP yield 分析,以及协程的实现,超详细版(上)
    PHP性能优化利器:生成器 yield理解
    swoole深入学习 8. 协程 转
    swoole| swoole 协程初体验 转
  • 原文地址:https://www.cnblogs.com/heze/p/12248251.html
Copyright © 2011-2022 走看看