zoukankan      html  css  js  c++  java
  • 【Keras案例学习】CNN做手写字符分类(mnist_cnn)

    from __future__ import print_function
    import numpy as np
    np.random.seed(1337)
    
    from keras.datasets import mnist
    from keras.models import Sequential
    from keras.layers import Dense, Dropout, Activation, Flatten
    from keras.layers import Convolution2D, MaxPooling2D
    from keras.utils import np_utils
    from keras import backend as K
    
    # Training hyper-parameters
    batch_size = 128
    nb_classes = 10
    nb_epoch = 12

    # Input image dimensions: MNIST digits are 28x28 pixels
    img_rows, img_cols = 28, 28
    # Number of convolution kernels (filters) in each convolutional layer
    nb_filters = 32
    # Size of the max-pooling window
    pool_size = (2, 2)
    # Size of each convolution kernel
    kernel_size = (3, 3)
    # Keras ships MNIST already split into 60,000 training and 10,000 test images
    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # Add a single (grey-scale) channel axis in the position the configured
    # backend ordering expects, so the first Convolution2D layer interprets
    # input_shape correctly under either setting:
    #   'th' ordering: (samples, channels, rows, cols)
    #   'tf' ordering: (samples, rows, cols, channels); e.g. 100 RGB images of
    #                  16x32 pixels are stored as (100, 16, 32, 3).
    # Hard-coding channels-last (as before) breaks when keras.json selects 'th'.
    if K.image_dim_ordering() == 'th':
        X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
        X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)
    else:
        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
        X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)

    # Convert to float32 so the in-place division below yields fractions
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    # Normalize: assumes raw pixel intensities in [0, 255], mapping them to [0, 1]
    X_train /= 255
    X_test /= 255
    # Report dataset sizes
    print('X_train shape:', X_train.shape)
    print(X_train.shape[0], 'train samples')
    print(X_test.shape[0], 'test samples')
    
    X_train shape: (60000, 28, 28, 1)
    60000 train samples
    10000 test samples
    
    # One-hot encode the integer class labels (0 .. nb_classes-1) into binary
    # class matrices with nb_classes columns, as categorical cross-entropy expects.
    # A generator expression is used so no loop variable leaks into module scope.
    Y_train, Y_test = (np_utils.to_categorical(labels, nb_classes)
                       for labels in (y_train, y_test))
    
    # Build the network as one Sequential stack, listed from input to output.
    model = Sequential([
        # First 2-D convolution: nb_filters kernels of size kernel_size slide over
        # the image. Being the first layer, it must be given input_shape.
        # border_mode controls padding: 'valid' pads nothing, so each spatial
        # dimension shrinks by (kernel size - 1), e.g. 28 -> 26 here; 'same'
        # zero-pads so the output keeps the input size; 'full' pads further so
        # the output grows (Keras 1 supports 'full' only on the Theano backend).
        Convolution2D(nb_filters, kernel_size[0], kernel_size[1],
                      border_mode='valid',
                      input_shape=input_shape),
        Activation('relu'),

        # Second convolution with the default border_mode, followed by ReLU
        Convolution2D(nb_filters, kernel_size[0], kernel_size[1]),
        Activation('relu'),

        # Max pooling over pool_size windows, then 25% dropout
        MaxPooling2D(pool_size=pool_size),
        Dropout(0.25),

        # Flatten the feature maps into a vector for the fully connected layers
        Flatten(),

        # Fully connected hidden layer: 128 ReLU units with 50% dropout
        Dense(128),
        Activation('relu'),
        Dropout(0.5),

        # Output layer: nb_classes units with softmax activation
        Dense(nb_classes),
        Activation('softmax'),
    ])

    # Print per-layer output shapes and parameter counts
    model.summary()
    # Configure training: categorical cross-entropy loss on the one-hot targets,
    # the Adadelta optimizer, and accuracy as a reported metric
    model.compile(optimizer='adadelta',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    
    ____________________________________________________________________________________________________
    Layer (type)                     Output Shape          Param #     Connected to                     
    ====================================================================================================
    convolution2d_3 (Convolution2D)  (None, 26, 26, 32)    320         convolution2d_input_2[0][0]      
    ____________________________________________________________________________________________________
    activation_5 (Activation)        (None, 26, 26, 32)    0           convolution2d_3[0][0]            
    ____________________________________________________________________________________________________
    convolution2d_4 (Convolution2D)  (None, 24, 24, 32)    9248        activation_5[0][0]               
    ____________________________________________________________________________________________________
    activation_6 (Activation)        (None, 24, 24, 32)    0           convolution2d_4[0][0]            
    ____________________________________________________________________________________________________
    maxpooling2d_2 (MaxPooling2D)    (None, 12, 12, 32)    0           activation_6[0][0]               
    ____________________________________________________________________________________________________
    dropout_3 (Dropout)              (None, 12, 12, 32)    0           maxpooling2d_2[0][0]             
    ____________________________________________________________________________________________________
    flatten_2 (Flatten)              (None, 4608)          0           dropout_3[0][0]                  
    ____________________________________________________________________________________________________
    dense_3 (Dense)                  (None, 128)           589952      flatten_2[0][0]                  
    ____________________________________________________________________________________________________
    activation_7 (Activation)        (None, 128)           0           dense_3[0][0]                    
    ____________________________________________________________________________________________________
    dropout_4 (Dropout)              (None, 128)           0           activation_7[0][0]               
    ____________________________________________________________________________________________________
    dense_4 (Dense)                  (None, 10)            1290        dropout_4[0][0]                  
    ____________________________________________________________________________________________________
    activation_8 (Activation)        (None, 10)            0           dense_4[0][0]                    
    ====================================================================================================
    Total params: 600,810
    Trainable params: 600,810
    Non-trainable params: 0
    ____________________________________________________________________________________________________
    
    # Train for nb_epoch epochs in mini-batches of batch_size, reporting loss and
    # accuracy on the test set after every epoch (verbose=1 shows a progress bar).
    model.fit(
        X_train,
        Y_train,
        validation_data=(X_test, Y_test),
        batch_size=batch_size,
        nb_epoch=nb_epoch,
        verbose=1,
    )

    # Compute the loss and compiled metrics on the test set, batch by batch,
    # without printing progress
    score = model.evaluate(X_test, Y_test, verbose=0)
    
    Train on 60000 samples, validate on 10000 samples
    Epoch 1/12
    60000/60000 [==============================] - 18s - loss: 0.3675 - acc: 0.8886 - val_loss: 0.0877 - val_acc: 0.9722
    Epoch 2/12
    60000/60000 [==============================] - 13s - loss: 0.1346 - acc: 0.9598 - val_loss: 0.0623 - val_acc: 0.9802
    Epoch 3/12
    60000/60000 [==============================] - 13s - loss: 0.1039 - acc: 0.9691 - val_loss: 0.0527 - val_acc: 0.9837
    Epoch 4/12
    60000/60000 [==============================] - 13s - loss: 0.0887 - acc: 0.9736 - val_loss: 0.0462 - val_acc: 0.9849
    Epoch 5/12
    60000/60000 [==============================] - 13s - loss: 0.0778 - acc: 0.9763 - val_loss: 0.0420 - val_acc: 0.9860
    Epoch 6/12
    60000/60000 [==============================] - 13s - loss: 0.0698 - acc: 0.9794 - val_loss: 0.0383 - val_acc: 0.9871
    Epoch 7/12
    60000/60000 [==============================] - 14s - loss: 0.0659 - acc: 0.9802 - val_loss: 0.0374 - val_acc: 0.9868
    Epoch 8/12
    60000/60000 [==============================] - 14s - loss: 0.0616 - acc: 0.9818 - val_loss: 0.0385 - val_acc: 0.9877
    Epoch 9/12
    60000/60000 [==============================] - 14s - loss: 0.0563 - acc: 0.9829 - val_loss: 0.0338 - val_acc: 0.9881
    Epoch 10/12
    60000/60000 [==============================] - 14s - loss: 0.0531 - acc: 0.9845 - val_loss: 0.0320 - val_acc: 0.9889
    Epoch 11/12
    60000/60000 [==============================] - 13s - loss: 0.0498 - acc: 0.9855 - val_loss: 0.0323 - val_acc: 0.9890
    Epoch 12/12
    60000/60000 [==============================] - 14s - loss: 0.0479 - acc: 0.9852 - val_loss: 0.0329 - val_acc: 0.9892
    
    # Report the trained model's test-set results: score[0] is the loss,
    # score[1] the accuracy metric configured in compile()
    for position, caption in enumerate(('Test score:', 'Test accuracy:')):
        print(caption, score[position])
    
    Test score: 0.032927570413
    Test accuracy: 0.9892
  • 相关阅读:
    第二十七节(多线程、线程的创建和启动、生命周期、调度、控制、同步)
    第二十六节(对象流,File类)
    第二十五节(转换流,打印流)
    第二十四节(Java文件流,缓冲流)
    第二十三节(String,StringBuffer,基础类型对应的 8 个包装类,日期相关类、 Random 数字 ,Enum枚举)下
    【转】perl如何避免脚本在windows中闪一下就关闭
    计算机基础(二)
    计算机基础(一)
    04 数据结构
    03 逻辑与结构
  • 原文地址:https://www.cnblogs.com/surfzjy/p/6445437.html
Copyright © 2011-2022 走看看