zoukankan      html  css  js  c++  java
  • 【Python】Keras 使用 LeNet-5 识别 MNIST

     原始论文中的 LeNet-5 网络结构如下图：

    Keras 生成的网络结构如下图（由代码末尾的 plot_model 输出到 model.png）：

    代码如下:

    import numpy as np
    from keras.preprocessing import image
    from keras.models import Sequential
    from keras.layers import Dense, Dropout, Flatten, Activation
    from keras.layers import Conv2D, MaxPooling2D
    from keras.utils.vis_utils import plot_model
    from keras.utils import np_utils
    
    
    # 从文件夹图像与标签文件载入数据
    def create_x(filenum, file_dir):
        """Load a folder of numbered bitmaps as one normalized image array.

        Reads ``file_dir + "0.bmp"`` up to ``file_dir + str(filenum - 1) + ".bmp"``.
        ``file_dir`` is used as a plain string prefix, so it normally ends with
        a path separator. Each image is resized to 28x28, converted to 8-bit
        grayscale ('L' mode) and turned into an array; the arrays are stacked
        in index order.

        Args:
            filenum: number of images to read.
            file_dir: string prefix prepended to every image file name.

        Returns:
            A float32 NumPy array with pixel values divided by 255, i.e. in
            [0, 1]. Each entry is 28x28x1 under Keras' default channels_last
            image data format (matching the model's ``input_shape``).
        """
        def _load_gray(index):
            # target_size makes load_img resize to 28x28; convert('L')
            # collapses the image to a single luminance channel.
            picture = image.load_img(file_dir + str(index) + ".bmp", target_size=(28, 28)).convert('L')
            return image.img_to_array(picture)

        stacked = np.array([_load_gray(index) for index in range(filenum)])
        stacked = stacked.astype('float32')
        stacked /= 255
        return stacked
    
    
    def create_y(classes, filename):
        """Read integer class labels from a text file and one-hot encode them.

        The file holds one integer label per line. Blank lines, such as a
        trailing empty line, are skipped instead of crashing the parse.

        Args:
            classes: total number of classes, forwarded to ``to_categorical``.
            filename: path of the label file.

        Returns:
            The one-hot label matrix produced by ``np_utils.to_categorical``,
            with one row per non-blank line, in file order.

        Raises:
            OSError: if ``filename`` cannot be opened.
            ValueError: if a non-blank line is not an integer.
        """
        # ``with`` closes the file even when int() raises partway through the
        # file. The previous bare open()/close() pair leaked the handle then.
        with open(filename, "r") as label_file:
            labels = [int(line) for line in label_file if line.strip()]
        train_y = np.array(labels).astype('float32')
        train_y = np_utils.to_categorical(train_y, classes)
        return train_y
    
    
    # Number of digit classes; also the width of the one-hot label vectors.
    classes = 10

    # Images come from the numbered bitmaps in ./train/ and ./test/ ...
    X_train, X_test = create_x(55000, './train/'), create_x(10000, './test/')

    # ... and the matching labels from train.txt / test.txt (one integer per line).
    Y_train, Y_test = create_y(classes, 'train.txt'), create_y(classes, 'test.txt')

    # Alternative (disabled): parse the MNIST dataset downloaded from the
    # network directly. The code is kept inside a string literal so it never runs.
    '''
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST/", one_hot=True)
    X_train, Y_train = mnist.train.images, mnist.train.labels
    X_test, Y_test = mnist.test.images, mnist.test.labels
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    print(X_train.shape, X_test.shape)
    X_train = X_train.reshape(X_train.shape[0], 28, 28, 1)
    X_test = X_test.reshape(X_test.shape[0], 28, 28, 1)
    print(X_train[0])
    '''
    # LeNet-5-style network for 28x28x1 inputs. The trailing comments give each
    # stage's output shape ('valid' 5x5 convolutions, 2x2 max pooling).
    lenet_layers = [
        Conv2D(filters=6, kernel_size=(5, 5), padding='valid', input_shape=(28, 28, 1), activation='tanh'),  # C1 -> 24x24x6
        MaxPooling2D(pool_size=(2, 2)),  # S2 -> 12x12x6
        Conv2D(filters=16, kernel_size=(5, 5), padding='valid', activation='tanh'),  # C3 -> 8x8x16
        MaxPooling2D(pool_size=(2, 2)),  # S4 -> 4x4x16
        Flatten(),  # -> 256 features
        Dense(120, activation='tanh'),  # C5 (a fully connected layer here)
        Dense(84, activation='tanh'),  # F6
        Dense(10, activation='softmax'),  # output: one probability per class
    ]
    model = Sequential()
    for layer in lenet_layers:
        model.add(layer)
    model.summary()

    # Plain SGD on categorical cross-entropy, tracking accuracy.
    # NOTE(review): the test set is also passed as validation_data, so the
    # per-epoch val_* metrics are not an independent held-out estimate.
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    history = model.fit(X_train, Y_train,
                        epochs=50, batch_size=500,
                        validation_data=(X_test, Y_test), verbose=1)
    score = model.evaluate(X_test, Y_test, verbose=0)  # [loss, accuracy]

    # Predicted class for each test image = index of its largest softmax output.
    test_result = model.predict(X_test)
    result = test_result.argmax(axis=1)

    print(result)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])

    # Save the architecture diagram, including tensor shapes, to model.png.
    plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=False)

    训练 50 个 epoch（轮）后，测试集识别率在 97% 左右：

     相关测试数据可以在这里下载到。

  • 相关阅读:
    Java中Vector和ArrayList的区别
    多线程
    集合框架
    5种运行时异常+1道面试题
    事务,视图,索引,备份和恢复
    MYSQL常用函数
    SQL数据库表字段明细导入导出
    SqlServer 命令方式备份与还原
    .NetCore IIS发布后PUT、DELETE请求错误405.0
    大数据中HBase的Java接口封装
  • 原文地址:https://www.cnblogs.com/tiandsp/p/9644698.html
Copyright © 2011-2022 走看看