zoukankan      html  css  js  c++  java
  • Keras cnn 手写数字识别示例

    #基于mnist数据集的手写数字识别

    #构造了cnn网络拟合识别函数,前两层为卷积层,第三层为池化层,第四层为Flatten层,最后两层为全连接层

    #基于Keras 2.1.1 Tensorflow 1.4.0

    代码:

     1 from __future__ import print_function
     2 import numpy as np
     3 np.random.seed(1337)
     4 from keras.datasets import mnist
     5 from keras.models import Sequential
     6 from keras.layers import Dense, Dropout, Activation, Flatten
     7 from keras.layers import Convolution2D, MaxPooling2D
     8 from keras.utils import np_utils
     9 from keras import backend as K
    10 
    11 batch_size = 128
    12 nb_classes = 10
    13 nb_epoch = 12
    14 
    15 img_rows, img_cols = 28, 28
    16 nb_filters = 32
    17 pool_size = (2,2)
    18 kernel_size = (3,3)
    19 (X_train, y_train), (X_test, y_test) = mnist.load_data()
    20 
    21 X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
    22 X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
    23 input_shape = (img_rows, img_cols, 1)
    24 X_train = X_train.astype('float32')
    25 X_test = X_test.astype('float32')
    26 X_train /= 255
    27 X_test /= 255
    28 
    29 Y_train = np_utils.to_categorical(y_train, nb_classes)
    30 Y_test = np_utils.to_categorical(y_test, nb_classes)
    31 # 建立序贯模型
    32 model = Sequential()
    33 
    34 model.add(Convolution2D(nb_filters, kernel_size[0] ,kernel_size[1],border_mode='valid',input_shape=input_shape))
    35 model.add(Activation('relu'))
    36 
    37 
    38 model.add(Convolution2D(nb_filters, kernel_size[0], kernel_size[1]))
    39 model.add(Activation('relu'))
    40 model.add(MaxPooling2D(pool_size=pool_size))
    41 model.add(Dropout(0.25))
    42 model.add(Flatten())
    43 model.add(Dense(128))
    44 model.add(Activation('relu'))
    45 model.add(Dropout(0.5))
    46 model.add(Dense(nb_classes))
    47 model.add(Activation('softmax'))
    48 
    49 
    50 model.summary()
    51 model.compile(loss='categorical_crossentropy',optimizer='adadelta',metrics=['accuracy'])
    52 model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch,verbose=1, validation_data=(X_test, Y_test))
    53 
    54 score = model.evaluate(X_test, Y_test, verbose=0)
    55 print('Test score:', score[0])
    56 print('Test accuracy:', score[1])
  • 相关阅读:
    将excel中的sheet1导入到sqlserver中
    .net中 Timer定时器
    Exception异常处理机制
    算法
    八、上网行为管理
    获取网站路径绝对路径的方法汇总
    Window逆向基础之逆向工程介绍
    Java Web代码审计流程与漏洞函数
    创建一个Java Web项目,获取POST数据并显示
    七、虚拟专用网
  • 原文地址:https://www.cnblogs.com/cnXuYang/p/8995927.html
Copyright © 2011-2022 走看看