zoukankan      html  css  js  c++  java
  • 15-手写数字识别-小数据集

    1.手写数字数据集

    • from sklearn.datasets import load_digits
    • digits = load_digits()

     

    2.图片数据预处理

    • x:归一化MinMaxScaler()
    • y:独热编码OneHotEncoder()或to_categorical
    • 训练集测试集划分
    • 张量结构
    from sklearn.datasets import load_digits
    from sklearn.preprocessing import MinMaxScaler
    from sklearn.preprocessing import OneHotEncoder
    import numpy as np
    from sklearn.model_selection import train_test_split
    import matplotlib.pyplot as plt
    
    digits=load_digits() #获取数据集
    
    X_data=digits.data.astype(np.float32)#样本数据
    print("样本数据:
    ",X_data)
    #对x归一化处理MinMaxScaler()
    scaler=MinMaxScaler()
    X_data=scaler.fit_transform(X_data)
    print("归一化后处理后的样本数据:
    ",X_data)
    
    Y_data=digits.target.astype(np.float32).reshape(-1,1)#将Y_data变为一列
    print("样本标签:
    ",Y_data)
    #对y进行独热编码OneHotEncoder()
    Y=OneHotEncoder().fit_transform(Y_data).todense()
    print("独热编码后的样本标签:
    ",Y)
    #转换为图片的格式
    X=X_data.reshape(-1,8,8,1)
    
    #训练集测试集划分
    x_train,x_test,y_train,y_test=train_test_split(X,Y,test_size=0.2,random_state=0,stratify=Y)
    print('训练集数据:',x_train.shape,'测试集数据:',x_test.shape)
    print('
    训练集标签:',y_train.shape,'测试集标签:',y_test.shape)

    运行结果:

    3.设计卷积神经网络结构

    • 绘制模型结构图,并说明设计依据。

              

    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPool2D
    # 建立模型
    model = Sequential()
    
    # C1卷积层
    model.add(
         Conv2D(
            filters=16,#输出空间的维度
            kernel_size=(5, 5),#卷积核大小
            padding='same',#填充边界 
            input_shape=x_train.shape[1:],
            activation='relu'))
    
    # S2池化层
    model.add(MaxPool2D(pool_size=(2, 2)))
    
    # drop层(防止过拟合)
    model.add(Dropout(0.25))
    
    # C3卷积层
    model.add(
        Conv2D(
            filters=32,
            kernel_size=(5, 5), 
            padding='same',
            activation='relu'))
    
    
    # S4池化层
    model.add(MaxPool2D(pool_size=(2, 2)))
    
    model.add(Dropout(0.25))
    
    #C5卷积层
    model.add(
        Conv2D(
            filters=64,
            kernel_size=(5, 5), 
            padding='same',
            activation='relu'))
    
    #C6卷积层
    model.add(
        Conv2D(
            filters=128,
            kernel_size=(5, 5), 
            padding='same',
            activation='relu'))
    
    # S7池化层
    model.add(MaxPool2D(pool_size=(2, 2)))
    
    model.add(Dropout(0.25))
    
    #平坦层
    model.add(Flatten())
    
    #全连接层
    model.add(Dense(128, activation='relu'))
    
    model.add(Dropout(0.25)) 
    
    #激活函数
    model.add(Dense(10, activation='softmax'))
    
    model.summary()

    运行结果:

    4.模型训练

    • model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    • train_history = model.fit(x=X_train,y=y_train,validation_split=0.2, batch_size=300,epochs=10,verbose=2)
    #模型训练
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    train_history = model.fit(x=x_train,y=y_train,
                              validation_split=0.2, 
                              batch_size=300,#每次训练的数据量
                              epochs=10,#训练迭代次数
                              verbose=2)

    运行结果:

    5.模型评价

    • model.evaluate()
    • 交叉表与交叉矩阵
    • pandas.crosstab
    • seaborn.heatmap
    # 模型评价
    score =model.evaluate(x_test,y_test)
    print("模型评价:",score)
    
    #预测值
    y_pred=model.predict_classes(x_test)
    print("预测值:",y_pred)
    
    #交叉矩阵查看预测数据与原数据对比
    import pandas as pd
    import seaborn as sns
    #标签数值化
    y_test1=np.argmax(y_test,axis=1).reshape(-1)
    y_ture=np.array(y_test1[0]).reshape(-1)
    
    a=pd.crosstab(y_ture,y_pred,rownames=['lables'],colnames=['predict'])
    #转换为数据框
    df=pd.DataFrame(a)
    #绘制热力图
    sns.heatmap(df,annot=True,cmap="YlGnBu",linewidths=0.2,linecolor='G')
    plt.show()

    运行结果:

     

  • 相关阅读:
    Jmeter 脚本录制
    Scrapy 爬虫模拟登陆的3种策略
    Scrapy Shell
    Ipython
    XPath helper
    python3 接口测试数据驱动之操作mysql数据库
    Pandas 基础(17)
    Pandas 基础(16)
    在 Laravel 项目中使用 Elasticsearch 做引擎,scout 全文搜索(小白出品, 绝对白话)
    Pandas 基础(15)
  • 原文地址:https://www.cnblogs.com/a1120139442/p/13111790.html
Copyright © 2011-2022 走看看