zoukankan      html  css  js  c++  java
  • 【Python】keras卷积神经网络识别mnist

    卷积神经网络的结构我随意设了一个。

    结构大概是下面这个样子:

    代码如下:

    import numpy as np
    from keras.preprocessing import image
    from keras.models import Sequential
    from keras.layers import Dense, Dropout, Flatten, Activation
    from keras.layers import Conv2D, MaxPooling2D
    
    # 从文件夹图像与标签文件载入数据
    def create_x(filenum, file_dir):
        train_x = []
        for i in range(filenum):
            img = image.load_img(file_dir + str(i) + ".bmp", target_size=(28, 28))
            img = img.convert('L')
            x = image.img_to_array(img)
            train_x.append(x)
        train_x = np.array(train_x)
        train_x = train_x.astype('float32')
        train_x /= 255
        return train_x
    
    
    def create_y(classes, filename):
        train_y = []
        file = open(filename, "r")
        for line in file.readlines():
            tmp = []
            for j in range(classes):
                if j == int(line):
                    tmp.append(1)
                else:
                    tmp.append(0)
            train_y.append(tmp)
        file.close()
        train_y = np.array(train_y).astype('float32')
        return train_y
    
    classes = 10
    X_train = create_x(55000, './train/')
    X_test = create_x(10000, './test/')
    
    Y_train = create_y(classes, 'train.txt')
    Y_test = create_y(classes, 'test.txt')
    
    # 从网络下载的数据集直接解析数据
    '''
    from tensorflow.examples.tutorials.mnist import input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    X_train, Y_train = mnist.train.images, mnist.train.labels
    X_test, Y_test = mnist.test.images, mnist.test.labels
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    '''
    model = Sequential()
    model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(Dropout(0.25))
    
    model.add(Flatten())
    model.add(Dense(81, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(10))
    model.add(Activation('softmax'))
    model.summary()
    
    model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
    history = model.fit(X_train, Y_train, batch_size=500, epochs=10, verbose=1, validation_data=(X_test, Y_test))
    score = model.evaluate(X_test, Y_test, verbose=0)
    
    test_result = model.predict(X_test)
    result = np.argmax(test_result, axis=1)
    
    print(result)
    print('Test score:', score[0])
    print('Test accuracy:', score[1])

    最终在测试集上识别率在99%左右。

    相关测试数据可以在这里下载到。

  • 相关阅读:
    mysql 应用 持续更新2 转载
    sql server 用触发器记录增删改操作(转载)
    mysql 应用 持续更新
    oracle 常用指令(持续更新中....)
    转载-Oracle 数据库导入导出 dmp文件
    Web Service 服务无法连接Oracle数据库
    关于jquery获取服务器端xml数据
    Navicat Premium 自动备份mysql和sqlserver
    浅谈如何更好的打开和关闭ADO.NET连接池
    JSON 的优点
  • 原文地址:https://www.cnblogs.com/tiandsp/p/9638876.html
Copyright © 2011-2022 走看看