zoukankan      html  css  js  c++  java
  • DNN识别mnist手写数字

    mnist数据下载地址:
    链接:https://pan.baidu.com/s/1GD2hI8Wf4oUR-V2NysYorw
    提取码:sg3f

    导库

    import numpy as np
    import matplotlib.pyplot as plt
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    

    读取mnist数据

    import numpy as np
    path='./mnist.npz'
    f = np.load(path)
    train_x, train_y = f['x_train'], f['y_train']    # 训练集
    test_x, test_y = f['x_test'], f['y_test']    # 测试集
    f.close()
    

    查看数据格式

    print(train_x.shape)
    print(train_y.shape)
    print(test_x.shape)
    print(test_y.shape)
    

    将数据以图片形式输出

    plt.imshow(train_x[10000])
    

    将数据格式改为DNN可接收的一维格式

    train_x = train_x.reshape((60000,28*28),order='C')    # 将二维的图片展开为一维的数据(训练集)  
    test_x = test_x.reshape((10000,28*28),order='C')    # 将二维的图片展开为一维的数据(测试集)
    

    搭建DNN并训练

    model = keras.Sequential()
    model.add(layers.Dense(100,activation='relu',input_dim=28*28))
    model.add(layers.Dense(10,activation='softmax'))
    adam = keras.optimizers.Adam(lr=0.01)
    model.compile(optimizer=adam,loss='sparse_categorical_crossentropy',metrics=['acc'])
    model.fit(train_x,train_y,epochs=50,batch_size=512)
    

    经过50轮训练后,DNN在训练集上的loss和准确率如下

    DNN在测试集上的loss和准确率如下

    model.evaluate(test_x,test_y)
    

    完整的代码如下

    import numpy as np
    import matplotlib.pyplot as plt
    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    
    path='./mnist.npz'
    f = np.load(path)
    train_x, train_y = f['x_train'], f['y_train']    # 训练集
    test_x, test_y = f['x_test'], f['y_test']    # 测试集
    f.close()
    
    print(train_x.shape)
    print(train_y.shape)
    print(test_x.shape)
    print(test_y.shape)
    
    plt.imshow(train_x[10000])
    
    train_x = train_x.reshape((60000,28*28),order='C')    # 将二维的图片展开为一维的数据(训练集)  
    test_x = test_x.reshape((10000,28*28),order='C')    # 将二维的图片展开为一维的数据(测试集)
    
    model = keras.Sequential()
    model.add(layers.Dense(100,activation='relu',input_dim=28*28))
    model.add(layers.Dense(10,activation='softmax'))
    model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])
    model.fit(train_x,train_y,epochs=50,batch_size=512)
    
    model.evaluate(test_x,test_y)
    
  • 相关阅读:
    Java8中利用stream对map集合进行过滤的方法
    安装数据库MySQL,启动时报错 服务没有响应控制功能 的解决办法
    mysql 安装时 失败,提示 因为计算机中丢失 msvcp140.dll
    复习一下数学排列组合公式的原理
    java如何进行排列组合运算
    Redis 分布式锁:使用Set+lua替代 setnx
    深入详解Go的channel底层实现原理【图解】
    MYSQL MVCC实现原理详解
    聚簇索引和非聚簇索引,全在这!!!
    深度解密Go语言之 map
  • 原文地址:https://www.cnblogs.com/bill-h/p/13906166.html
Copyright © 2011-2022 走看看