zoukankan      html  css  js  c++  java
  • 基于手写数字识别数据集的机器学习方法对比研究

    基于手写数字识别数据集的机器学习方法对比研究

    摘要

    研究意义:统计机器学习和深度学习都已被广泛地应用。

    主流研究方法:在相同的数据集上进行对比实验。

    前人研究存在的问题:在检索范围内,没有发现统计学习方法与深度学习方法对比的工作。

    我们的解决手段:本文在手写数字识别数据集(MNIST)上,对比了主流的统计机器学习方法和深度学习方法的表现。

    我们解决的还不错:通过实验证明了深度学习在 MNIST 数据集上的效果更好,准确率为97.50%;统计机器学习方法(SVM)准确率为93.71%。

    Keywords: 手写数字识别, MNIST, DNN, SVM, 统计机器学习, 深度学习

    实验

    实验设置

    Epoch : 10

    Train Data Sample : 60000

    Test Data Sample : 10000

    Image Shape : (28, 28, 1)

    实验结果

    预测性能

    方法 Acc on Train Acc on Test Paramters
    DNN 0.9950 0.9808 1,238,730
    CNN+MaxPooling 0.9906 0.9742 1,332,810
    kernel approximation + LinearSVC 0.9378 0.9371 N/A
    SVC 0.9899 0.9792 N/A

    执行效率

    CPU 80线程,128GB内存,固态硬盘

    方法 Training and Inference
    DNN 0m 38.849s
    CNN+MaxPooling 11m 19.786s
    kernel approximation + LinearSVC 0m 20.889s
    SVC 10m 54.445s

    结论

    1.深度学习方法在足量的数据上,可以取得比统计学习方法更高的准确率;

    2.CNN+MaxPooling方法在当前的“实验设置”下,过拟合了;

    3.在当前的“实验设置”下,DNN方法的效果一致好于CNN+MaxPooling方法;

    4.自带核函数的SVM预测效果,好于近似核函数和线性SVM的组合方法;

    5.自带核函数的SVM,训练时间和推断时间都远高于近似核函数和线性SVM的组合方法,高于DNN,略低于CNN;

    代码

    DNN

    # encoder=utf-8
    
    from tensorflow import keras
    from tensorflow.keras import Model, layers
    from tensorflow.keras.utils import to_categorical
    import numpy as np
    
    
    # Load Dataset
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    print(x_train.shape)
    
    # Reshape the data
    x_train = np.reshape(x_train, (len(x_train), 28 * 28)) / 255.0
    x_test = np.reshape(x_test, (len(x_test), 28 * 28)) / 255.0
    print(x_train.shape)
    
    # categorical labels
    y_train = to_categorical(y_train, num_classes=10)
    y_test = to_categorical(y_test, num_classes=10)
    print(y_train.shape)
    
    # Define and build the model
    input_img = layers.Input(shape=28*28)
    x = layers.Dense(28*28, activation='relu')(input_img)
    x = layers.Dense(28*28, activation='sigmoid')(x)
    x = layers.Dense(10, activation='softmax')(x)
    model = Model(input_img, x)
    model.summary()
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics='acc'
    )
    model.fit(
        x=x_train,
        y=y_train,
        batch_size=128,
        epochs=10
    )
    loss, metric = model.evaluate(x=x_test, y=y_test, batch_size=128, )
    print("cross entropy is %.4f, accuracy is %.4f" % (loss, metric))
    

    CNN + MaxPooling

    # encoder=utf-8
    
    from tensorflow import keras
    from tensorflow.keras import Model, layers
    from tensorflow.keras.utils import to_categorical
    
    
    # Load Dataset
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    print(x_train.shape)
    
    # normalize the data
    x_train = x_train / 255.0
    x_test = x_test / 255.0
    
    # categorical labels
    y_train = to_categorical(y_train, num_classes=10)
    y_test = to_categorical(y_test, num_classes=10)
    print(y_train.shape)
    
    # Define and build the model
    input_img = layers.Input(shape=(28, 28, 1))
    x = layers.Conv2D(28*28, (3, 3))(input_img)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)
    x = layers.Dense(10, activation='softmax')(x)
    model = Model(input_img, x)
    model.summary()
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics='acc'
    )
    model.fit(
        x=x_train,
        y=y_train,
        batch_size=128,
        epochs=10
    )
    loss, metric = model.evaluate(x=x_test, y=y_test, batch_size=128, )
    print("cross entropy is %.4f, accuracy is %.4f" % (loss, metric))
    

    Kernel approximation + LinearSVM

    # encoder=utf-8
    
    from tensorflow import keras
    import numpy as np
    from sklearn.kernel_approximation import Nystroem
    from sklearn.svm import LinearSVC
    from sklearn.metrics import accuracy_score
    
    
    # Load Dataset
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    print(x_train.shape)
    
    # Reshape the data
    x_train = np.reshape(x_train, (len(x_train), 28 * 28)) / 255.0
    x_test = np.reshape(x_test, (len(x_test), 28 * 28)) / 255.0
    print(x_train.shape)
    print(y_train.shape)
    
    # Define and build the kernel mapping
    x = np.concatenate((x_train, x_test))
    print(x.shape)
    
    # SVC is too slow to practice, hence we split the SVC into
    # approximating kernel map (sklearn.kernel_approximation.Nystroem)
    # and linear SVM (sklearn.svm.LinearSVC)
    feature_map_nystroem = Nystroem(n_components=28*28)
    feature_map_nystroem.fit(x)
    x = feature_map_nystroem.transform(x)
    
    x_train = x[:60000]
    x_test = x[60000:]
    print(x_train.shape)
    print(x_test.shape)
    
    cls = LinearSVC()
    cls.fit(x_train, y_train)
    y_pred = cls.predict(x_train)
    ret = accuracy_score(y_train, y_pred)
    print(ret)
    
    y_pred = cls.predict(x_test)
    ret = accuracy_score(y_test, y_pred)
    print(ret)
    

    SVC

    # encoder=utf-8
    
    from tensorflow import keras
    import numpy as np
    from sklearn.svm import SVC
    from sklearn.metrics import accuracy_score
    
    
    # Load Dataset
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    print(x_train.shape)
    
    # Reshape the data
    x_train = np.reshape(x_train, (len(x_train), 28 * 28)) / 255.0
    x_test = np.reshape(x_test, (len(x_test), 28 * 28)) / 255.0
    print(x_train.shape)
    print(y_train.shape)
    
    
    cls = SVC()
    cls.fit(x_train, y_train)
    y_pred = cls.predict(x_train)
    ret = accuracy_score(y_train, y_pred)
    print(ret)
    
    y_pred = cls.predict(x_test)
    ret = accuracy_score(y_test, y_pred)
    print(ret)
    
    智慧在街市上呼喊,在宽阔处发声。
  • 相关阅读:
    mysql面试题
    Excel下载打不开
    Linux安装jdk1.8和配置环境变量
    Linux压缩、解压文件
    Linux常用命令1
    VMware下载安装及CentOS7下载安装
    ueditor的简单配置和使用
    linux的tomcat服务器上部署项目的方法
    TortoiseSVN客户端的使用说明
    CentOS 6.5系统上安装SVN服务器
  • 原文地址:https://www.cnblogs.com/fengyubo/p/15334661.html
Copyright © 2011-2022 走看看