import tensorflow as tf

# Load the MNIST training and test splits as (images, labels) pairs.
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()
从训练集中划分出验证集（前 5000 个样本）
# Hold out the first 5000 training samples as the validation split;
# only the remaining samples are kept for training.
valid_x, train_x = train_x[:5000], train_x[5000:]
valid_y, train_y = train_y[:5000], train_y[5000:]
显示训练集中的前 10 个样本
import matplotlib.pyplot as plt

# Preview the first 10 training digits in a 2x5 grid, each labelled with its class.
plt.figure(figsize=(12, 5))
for i in range(10):
    plt.subplot(2, 5, i + 1)
    # Each image is a 2-D single-channel array; without an explicit colormap
    # matplotlib renders it in its default false-colour map, so request grayscale.
    plt.imshow(train_x[i], cmap='gray')
    plt.xlabel(train_y[i])
    plt.xticks([])
    plt.yticks([])
plt.savefig('lenet-mnist')
plt.show()
数据标准化（零均值、单位方差）
from sklearn.preprocessing import StandardScaler
import numpy as np

# Standardize pixel intensities. reshape(-1, 1) treats every pixel as a sample of
# one feature, so a single global mean/std is learned, then images are restored to 28x28.
scaler = StandardScaler()
# Fit the statistics on the training split ONLY ...
train_x_scaled = scaler.fit_transform(
    train_x.astype(np.float32).reshape(-1, 1)
).reshape(-1, 28, 28)
# ... and reuse them for the held-out splits. Re-fitting here would give each split
# its own mean/std (inconsistent inputs) and leak held-out statistics into preprocessing.
valid_x_scaled = scaler.transform(
    valid_x.astype(np.float32).reshape(-1, 1)
).reshape(-1, 28, 28)
test_x_scaled = scaler.transform(
    test_x.astype(np.float32).reshape(-1, 1)
).reshape(-1, 28, 28)
# Append a trailing channel axis of size 1 so every split has shape
# (num_samples, 28, 28, 1), matching the Conv2D input_shape used by the model.
train_x = tf.expand_dims(train_x_scaled, axis=-1)
valid_x = tf.expand_dims(valid_x_scaled, axis=-1)
test_x = tf.expand_dims(test_x_scaled, axis=-1)
构建 LeNet 网络
# Only `tensorflow as tf` is imported elsewhere; bring the `keras` name into scope here.
from tensorflow import keras

# LeNet-style CNN: two conv + max-pool stages, then three fully connected layers.
model = keras.models.Sequential()
model.add(keras.layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 1)))
model.add(keras.layers.MaxPool2D(2, 2))
model.add(keras.layers.Conv2D(16, (5, 5), activation='relu'))
model.add(keras.layers.MaxPool2D(2, 2))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(120, activation='relu'))
model.add(keras.layers.Dense(84, activation='relu'))
# The 10 digit classes are mutually exclusive, so the output must be a probability
# distribution over them (softmax), not 10 independent sigmoid scores.
model.add(keras.layers.Dense(10, activation='softmax'))
配置回调并编译网络
import os
import sys
# Only `tensorflow as tf` is imported elsewhere; bring the `keras` name into scope here.
from tensorflow import keras

lenet_logdir = 'lenet_logdir'
# Make sure the directory for the checkpoint file exists before the first save.
os.makedirs(lenet_logdir, exist_ok=True)
out_model = os.path.join(lenet_logdir, 'lenet_mnist.h5')
callbacks = [
    # TensorBoard event files go to the log directory ...
    keras.callbacks.TensorBoard(lenet_logdir),
    # ... while the best model (by validation loss, the default monitor) is written to
    # the .h5 file rather than into the log directory itself.
    keras.callbacks.ModelCheckpoint(out_model, save_best_only=True),
]
# Labels are integer class ids, hence the *sparse* categorical cross-entropy.
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
训练网络
# Train for 30 epochs, scoring the held-out validation split as training proceeds;
# the recorded per-epoch metrics are kept in `history` for plotting.
history = model.fit(
    train_x,
    train_y,
    epochs=30,
    validation_data=(valid_x, valid_y),
    callbacks=callbacks,
)
可视化训练过程（损失与准确率曲线）
import matplotlib.pyplot as plt
import pandas as pd

# Plot every metric recorded by model.fit, one curve per metric, indexed by epoch.
pd.DataFrame(history.history).plot(figsize=(10, 4))
plt.grid(True)
# NOTE: the y-range is fixed to [0, 1]; any loss values above 1 are clipped from view.
plt.gca().set_ylim(0, 1)
# Leave the x-axis autoscaled so all trained epochs are visible; a fixed (0, 5)
# limit would hide most of the 30-epoch run.
plt.show()
在测试集上评估模型
# Report loss and accuracy on the test split, which was not used for training or validation.
model.evaluate(x=test_x, y=test_y)
10000/10000 [==============================] - 1s 56us/sample - loss: 0.0352 - accuracy: 0.9899
[0.03522909182251751, 0.9899]