zoukankan      html  css  js  c++  java
  • CNN识别mnist手写数字

    mnist数据的下载、读取部分请参见:DNN识别mnist手写数字

    为了使读取到的图片数据能输入CNN,需要为图片数据增加channel维度

    train_x = np.expand_dims(train_x,axis=-1)
    test_x = np.expand_dims(test_x,axis=-1)
    

    查看增维后数据的维度

    print(train_x.shape)
    print(test_x.shape)
    

    搭建CNN并训练

    drop_rate = 0.01
    model = keras.Sequential()
    model.add(layers.Conv2D(64,(3,3),activation='relu',input_shape=(28,28,1)))
    model.add(layers.MaxPooling2D())
    model.add(layers.Flatten())
    model.add(layers.Dense(200,activation='relu'))
    model.add(layers.Dropout(drop_rate))
    model.add(layers.Dense(10,activation='softmax'))
    adam = keras.optimizers.Adam(lr=0.001)
    model.compile(optimizer=adam,loss='sparse_categorical_crossentropy',metrics=['acc'])
    model.fit(train_x,train_y,epochs=10,batch_size=512)
    

    经过10轮训练后,CNN在训练集上的loss和准确率如下

    CNN在测试集上的loss和准确率如下

    model.evaluate(test_x,test_y)
    

  • 相关阅读:
    1206 冲刺三
    1130持续更新
    1128项目跟进
    冲刺一1123(总结)
    冲刺一
    1117 新冲刺
    0621 第三次冲刺及课程设计
    0621回顾和总结
    实验四主存空间的分配和回收
    学习进度条
  • 原文地址:https://www.cnblogs.com/bill-h/p/13907572.html
Copyright © 2011-2022 走看看