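For downloading and loading the MNIST data, see: Recognizing MNIST Handwritten Digits with a DNN
Conv2D expects each image to have the shape (height, width, channels), so a channel dimension must be added to the image data before it can be fed to the CNN: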
train_x = np.expand_dims(train_x,axis=-1)
test_x = np.expand_dims(test_x,axis=-1)
Check the shape of the data after adding the channel dimension:
print(train_x.shape)
print(test_x.shape)
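To make the effect of np.expand_dims concrete, here is a minimal sketch on a small dummy array. The array `dummy` is a hypothetical stand-in for train_x; the real arrays should gain a trailing dimension of size 1 in the same way.

```python
import numpy as np

# Hypothetical stand-in for train_x: 5 grayscale images of 28x28 pixels
dummy = np.zeros((5, 28, 28), dtype=np.float32)

# axis=-1 appends a new axis of length 1 at the end (the channel axis)
dummy_with_channel = np.expand_dims(dummy, axis=-1)

print(dummy.shape)               # (5, 28, 28)
print(dummy_with_channel.shape)  # (5, 28, 28, 1)
```

The pixel values are untouched; only the shape changes, so the data matches the input_shape=(28,28,1) expected by the first convolution layer.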
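Build and train the CNN
# Imports repeated here for completeness; assumes the Keras bundled with TensorFlow 2.x
from tensorflow import keras
from tensorflow.keras import layers
drop_rate = 0.01  # fraction of units dropped by the Dropout layer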
model = keras.Sequential()
model.add(layers.Conv2D(64,(3,3),activation='relu',input_shape=(28,28,1)))
model.add(layers.MaxPooling2D())
model.add(layers.Flatten())
model.add(layers.Dense(200,activation='relu'))
model.add(layers.Dropout(drop_rate))
model.add(layers.Dense(10,activation='softmax'))
adam = keras.optimizers.Adam(learning_rate=0.001)  # 'lr' is a deprecated alias of learning_rate
model.compile(optimizer=adam,loss='sparse_categorical_crossentropy',metrics=['acc'])
model.fit(train_x,train_y,epochs=10,batch_size=512)
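As a sanity check on the architecture above, the output shape of each layer and the number of trainable parameters can be worked out by hand. The sketch below assumes the Keras defaults for the layers as written (stride 1 and no padding for Conv2D, a 2x2 pool for MaxPooling2D); the helper functions are hypothetical and are not part of Keras.

```python
def conv2d_valid(h, w, c_in, filters, k):
    """Output shape and parameter count of a k x k convolution, stride 1, no padding."""
    out_shape = (h - k + 1, w - k + 1, filters)
    params = (k * k * c_in + 1) * filters  # +1 for each filter's bias
    return out_shape, params

def max_pool(h, w, c, pool=2):
    """Output shape of non-overlapping pooling with a pool x pool window."""
    return (h // pool, w // pool, c)

def dense(n_in, units):
    """Parameter count of a fully connected layer: weights plus biases."""
    return (n_in + 1) * units

conv_shape, conv_params = conv2d_valid(28, 28, 1, filters=64, k=3)
pool_shape = max_pool(*conv_shape)
flat = pool_shape[0] * pool_shape[1] * pool_shape[2]
dense1_params = dense(flat, 200)
dense2_params = dense(200, 10)
total = conv_params + dense1_params + dense2_params

print(conv_shape, pool_shape, flat)  # (26, 26, 64) (13, 13, 64) 10816
print(conv_params, dense1_params, dense2_params, total)  # 640 2163400 2010 2166050
```

Under these assumptions, almost all of the parameters sit in the first Dense layer, and calling model.summary() should report the same shapes and counts.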
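The CNN's loss and accuracy on the training set after 10 epochs are reported in the output of model.fit above.
Evaluate the CNN's loss and accuracy on the test set: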
model.evaluate(test_x,test_y)
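With the compile settings above, model.evaluate returns the loss followed by the accuracy. As a rough illustration of what those two numbers mean, the sketch below computes sparse categorical cross-entropy and accuracy by hand with NumPy. The probabilities and labels are made-up values for illustration only, and the small numerical clipping that Keras applies internally is ignored.

```python
import numpy as np

# Hypothetical softmax outputs for 3 samples over 3 classes (each row sums to 1)
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.4, 0.3]])
labels = np.array([0, 1, 2])  # integer class labels, one per sample

# Sparse categorical cross-entropy: mean negative log-probability of the true class
true_class_probs = probs[np.arange(len(labels)), labels]
loss = -np.mean(np.log(true_class_probs))

# Accuracy: fraction of samples whose highest-probability class equals the label
predictions = np.argmax(probs, axis=-1)
accuracy = np.mean(predictions == labels)

print(predictions)     # [0 1 1]
print(loss, accuracy)
```

The third sample is misclassified (predicted class 1, true class 2), so the accuracy here is 2/3. For the real model, the same argmax step applied to model.predict(test_x) would give the predicted digit for each test image.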