zoukankan      html  css  js  c++  java
  • 用Kersa搭建神经网络【MNIST手写数据集】

    MNIST手写数据集的识别算得上是深度学习的”hello world“了,所以想要入门必须得掌握。新手入门可以考虑使用Keras框架达到快速实现的目的。

    完整代码如下:

    # 1. 导入库和模块
    from keras.models import Sequential
    from keras.layers import Conv2D, MaxPool2D
    from keras.layers import Dense, Flatten
    from keras.utils import to_categorical
    
    # 2. 加载数据
    from keras.datasets import mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    
    # 3. 数据预处理
    img_x, img_y = 28, 28
    x_train = x_train.reshape(x_train.shape[0], img_x, img_y, 1)
    x_test = x_test.reshape(x_test.shape[0], img_x, img_y, 1)
    #数据标准化
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    #一位有效编码
    y_train = to_categorical(y_train, 10)
    y_test = to_categorical(y_test, 10)
    
    # 4. 定义模型结构
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(5,5), activation='relu', input_shape=(img_x, img_y, 1)))
    model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
    model.add(Conv2D(64, kernel_size=(5,5), activation='relu'))
    model.add(MaxPool2D(pool_size=(2,2), strides=(2,2)))
    model.add(Flatten())
    model.add(Dense(1000, activation='relu'))
    model.add(Dense(10, activation='softmax'))
    
    # 5. 编译,声明损失函数和优化器
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    
    # 6. 训练
    model.fit(x_train, y_train, batch_size=128, epochs=10)
    
    # 7. 评估模型
    score = model.evaluate(x_test, y_test)
    print('acc', score[1])

    运行结果如下:

     

    可以看出准确率达到了99%,说明神经网络在图像识别上具有巨大的优势。

  • 相关阅读:
    数据库访问表的问题
    UVA 10943全加和(规律)
    POJ 2594 最小路径覆盖 + 传递闭包
    phonegap入门7 capture.captureVideo 录像
    第二部分 Linux Shell高级编程技巧——第二章 Shell工具
    C#写的光模块烧写软件
    关于java的++和操作符,你真的搞明白了吗?
    MFCATL IDispatch调度接口
    c/c++函数调用约定
    HDOJ 2955 Robberies (0/1背包)
  • 原文地址:https://www.cnblogs.com/darklights/p/10385366.html
Copyright © 2011-2022 走看看