zoukankan      html  css  js  c++  java
  • 莫烦大大keras学习Mnist识别(3)-----CNN

    一、步骤:

    1. 导入模块以及读取数据

    2. 数据预处理

    3. 构建模型

    4. 编译模型

    5. 训练模型

    6. 测试

    二、代码:

    1. 导入模块以及读取数据

    #导包
    import numpy as np
    np.random.seed(1337)
    # from keras.datasets import mnist
    from keras.utils import np_utils # 主要采用这个模块下的to_categorical函数,将该函数转成one_hot向量
    from keras.models import Sequential #keras的模型模块
    from keras.layers import Dense , Activation , Convolution2D, MaxPooling2D, Flatten  #keras的层模块
    from keras.optimizers import Adam #keras的优化器
    
    
    
    #读取数据,因为本地已经下载好数据在绝对路径:E:jupyterTensorFlowMNIST_data下,直接采用TensorFlow来读取
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets('E:jupyterTensorFlowMNIST_data',one_hot = True)
    
    X_train = mnist.train.images
    Y_train = mnist.train.labels
    X_test = mnist.test.images
    Y_test = mnist.test.labels

    2、数据预处理

    x原本的shape为(55000,784),55000表示样本数量,784表示一个图像样本拉成一个向量的大小,故要将其转成28*28这种长×宽的形式。(-1,1,28,28)中的-1是之后batch_size的大小,即一次取batch大小的样本来训练,1,28,28表示高为1,长为28,宽为28。

    #数据预处理
    X_train = X_train.reshape(-1,1,28,28)
    X_test = X_test.reshape(-1,1,28,28)
    y_train = np_utils.to_categorical(Y_train,num_classes = 10)#to_categorical将标签转化成ont-hot
    y_test = np_utils.to_categorical(Y_test,num_classes = 10)

     

    3、构建模型

    2个卷积层【包括卷积+激活relu+最大池化】+2个全连接层

    #模型构建
    model = Sequential()  #建立一个序列模型
    
    #在这个模型首层添加一个卷积层,一个卷积过滤器大小为5*5,32个过滤器,采用的padding模式是same,即通过补0使输入输出大小一下。首层要加一个输入大小(1,28,28)
    model.add(Convolution2D(
        nb_filter = 32,
        nb_row = 5,
        nb_col = 5,
        border_mode = 'same',
        input_shape = (1,28,28)
    ))
    
    #接着加一个激活层
    model.add(Activation('relu'))
    
    #接着加一个最大池化层,pool大小为(2,2),strides步长长移动2,宽移动2。padding采用same模式
    model.add(MaxPooling2D(
        pool_size = (2,2),
        strides = (2,2),
        border_mode = 'same',
    ))
    
    #卷积层2
    model.add(Convolution2D(64,5,5,border_mode = 'same'))
    
    #激活层2
    model.add(Activation('relu'))
    
    #池化层2
    model.add(MaxPooling2D(pool_size = (2,2),border_mode = 'same'))
    
    #进行全连接之前将矩阵展开成一个长向量
    model.add(Flatten())
    
    #全连接层1,大小有1024个参数
    model.add(Dense(1024))
    
    #激活层
    model.add(Activation('relu'))
    
    #全连接层2,大小为10
    model.add(Dense(10))
    
    #输出层加一个softmax处理
    model.add(Activation('softmax'))

    4、编译模型:model.compile

    采用model.compile来编译,函数内参数说明优化器optimizer、损失函数loss、评价标准metrics。

    #编译模型
    adam = Adam(lr = 1e-4)
    model.compile(optimizer=adam,
                 loss = 'categorical_crossentropy',
                 metrics = ['accuracy'])

    5、训练模型:model.fit

    类似sklearn中的形式

    model.fit(X_train,Y_train,nb_epoch = 20,batch_size = 32)

    6、测试:model.evaluate

    输出测试的损失和准确度

    loss , acc = model.evaluate(X_test,y_test)
  • 相关阅读:
    mysql报错:java.sql.SQLException: The server time zone value 'Öйú±ê׼ʱ¼ä' is unrecognized or represents more than one time zone.
    MD5登陆密码的生成
    15. 3Sum、16. 3Sum Closest和18. 4Sum
    11. Container With Most Water
    8. String to Integer (atoi)
    6. ZigZag Conversion
    5. Longest Palindromic Substring
    几种非线性激活函数介绍
    AI初探1
    AI初探
  • 原文地址:https://www.cnblogs.com/Lee-yl/p/10121057.html
Copyright © 2011-2022 走看看