zoukankan      html  css  js  c++  java
  • 15 手写数字识别-小数据集

    1.手写数字数据集

    • from sklearn.datasets import load_digits
    • digits = load_digits()
    digits = load_digits()
    X_data = digits.data.astype(np.float32)
    Y_data = digits.target.astype(np.float32).reshape(-1, 1)

    2.图片数据预处理

    • x:归一化MinMaxScaler()
    • y:独热编码OneHotEncoder()或to_categorical
    • 训练集测试集划分
    • 张量结构
    # 将属性缩放到一个指定的最大和最小值(通常是1-0之间)
    # x:归一化MinMaxScaler()
    scaler = MinMaxScaler()
    X_data = scaler.fit_transform(X_data)
    X = X_data.reshape(-1, 8, 8, 1)
    print("MinMaxScaler_trans_X_data:")
    print(X_data)
    # y:独热编码OneHotEncoder 张量结构todense
    # 进行oe-hot编码
    Y = OneHotEncoder().fit_transform(Y_data).todense()
    print("one-hot_Y:")
    print(Y)
    # 训练集测试集划分
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0, stratify=Y)
    print('X_train.shape, X_test.shape, y_train.shape, y_test.shape:', X_train.shape, X_test.shape, y_train.shape, y_test.shape)

    结果:

     

     

     

    3.设计卷积神经网络结构

    • 绘制模型结构图,并说明设计依据。

    # 设计卷积神经网络结构
    # 建立模型
    model = Sequential()
    ks = [3, 3]  # 卷积核大小
    # 一层卷积,输入数据的shape要指定,其它层的数据shape框架会自动推导
    model.add(Conv2D(filters=16, kernel_size=ks, padding='same', input_shape=X_train.shape[1:], activation='relu'))
    # 池化层
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    # 二层卷积
    model.add(Conv2D(filters=32, kernel_size=ks, padding='same', activation='relu'))
    # 池化层
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    # 三层卷积
    model.add(Conv2D(filters=64, kernel_size=ks, padding='same', activation='relu'))
    # 四层卷积
    model.add(Conv2D(filters=128, kernel_size=ks, padding='same', activation='relu'))
    # 池化层
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    # 平坦层
    model.add(Flatten())
    # 全连接层
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.25))
    # 激活函数
    model.add(Dense(10, activation='softmax'))
    model.summary()

     结果:

    4.模型训练

    # 设计卷积神经网络结构
    # 建立模型
    model = Sequential()
    ks = [3, 3]  # 卷积核大小
    # 一层卷积,输入数据的shape要指定,其它层的数据shape框架会自动推导
    model.add(Conv2D(filters=16, kernel_size=ks, padding='same', input_shape=X_train.shape[1:], activation='relu'))
    # 池化层
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    # 二层卷积
    model.add(Conv2D(filters=32, kernel_size=ks, padding='same', activation='relu'))
    # 池化层
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    # 三层卷积
    model.add(Conv2D(filters=64, kernel_size=ks, padding='same', activation='relu'))
    # 四层卷积
    model.add(Conv2D(filters=128, kernel_size=ks, padding='same', activation='relu'))
    # 池化层
    model.add(MaxPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))
    # 平坦层
    model.add(Flatten())
    # 全连接层
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.25))
    # 激活函数
    model.add(Dense(10, activation='softmax'))
    model.summary()

    结果:

     

    5.模型评价

    • model.evaluate()
    • 交叉表与交叉矩阵
    • pandas.crosstab
    • seaborn.heatmap
    # 模型评价
    # 模型评估
    score = model.evaluate(x_test, y_test)[1]
    print('模型准确率=',score)
    # 预测值
    y_pre = model.predict_classes(x_test)
    y_pre[:10]
    
    # 交叉表和交叉矩阵
    y_test1 = np.argmax(y_test, axis=1).reshape(-1)
    y_true = np.array(y_test1)[0]
    y_true.shape
    # 交叉表查看预测数据与原数据对比
    pd.crosstab(y_true, y_pre, rownames=['true'], colnames=['predict'])
    
    # 交叉矩阵
    y_test1 = y_test1.tolist()[0]
    a = pd.crosstab(np.array(y_test1), y_pre, rownames=['Lables'], colnames=['predict'])
    df = pd.DataFrame(a)
    print(df)
    sns.heatmap(df, annot=True, cmap="Reds", linewidths=0.2, linecolor='G')

    结果:

     

  • 相关阅读:
    如何:为 Silverlight 客户端生成双工服务
    Microsoft Sync Framework 2.1 软件开发包 (SDK)
    Windows 下的安装phpMoAdmin
    asp.net安全检测工具 Padding Oracle 检测
    HTTP Basic Authentication for RESTFul Service
    Windows系统性能分析
    Windows Server AppFabric Management Pack for Operations Manager 2007
    Mongo Database 性能优化
    服务器未能识别 HTTP 标头 SOAPAction 的值
    TCP WAIT状态及其对繁忙的服务器的影响
  • 原文地址:https://www.cnblogs.com/wh008/p/13082400.html
Copyright © 2011-2022 走看看