zoukankan      html  css  js  c++  java
  • 15 手写数字识别-小数据集

    • from sklearn.datasets import load_digits
    • digits = load_digits()

    源代码:

     1 from sklearn.datasets import load_digits
     2 from sklearn.preprocessing import MinMaxScaler
     3 from sklearn.preprocessing import OneHotEncoder
     4 import numpy as np
     5 from sklearn.model_selection import train_test_split
     6 import matplotlib.pyplot as plt
     7 #1.手写数字数据集
     8 digits=load_digits() #获取数据
     9 #转换类型
    10 X_data=digits.data.astype(np.float32)
    11 Y_data=digits.target.astype(np.float32).reshape(-1,1)#将Y_data变为一列

    结果:

     

    2.图片数据预处理

    • x:归一化MinMaxScaler()
    • y:独热编码OneHotEncoder()或to_categorical
    • 训练集测试集划分
    • 张量结构

     源代码:

     1 #2.图片数据预处理
     2 scaler=MinMaxScaler()
     3 #x:归一化MinMaxScaler()
     4 X_data=scaler.fit_transform(X_data)
     5 
     6 #y:独热编码OneHotEncoder()
     7 Y=OneHotEncoder().fit_transform(Y_data).todense()
     8 
     9 #转换为图片的格式(batch,height,width,channels)
    10 X=X_data.reshape(-1,8,8,1)
    11 
    12 #训练集测试集划分
    13 x_train,x_test,y_train,y_test=train_test_split(X,Y,test_size=0.2,random_state=0,stratify=Y)
    14 print(x_train.shape,x_test.shape,y_train.shape,y_test.shape)

    结果:

    3.设计卷积神经网络结构

    • 绘制模型结构图,并说明设计依据。

    源代码:

     1 #3.设计卷积神经网络结构
     2 import os
     3 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
     4 # 导入相关包
     5 from tensorflow.keras.models import Sequential
     6 from tensorflow.keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPool2D
     7 
     8 # 建立模型
     9 model = Sequential()
    10 # 一层卷积
    11 model.add(Conv2D(filters=16,kernel_size=(5, 5),padding='same',input_shape=x_train.shape[1:],activation='relu'))
    12 # 池化层1
    13 model.add(MaxPool2D(pool_size=(2, 2)))
    14 model.add(Dropout(0.25))
    15 # 二层卷积
    16 model.add(Conv2D(filters=32,kernel_size=(5, 5),padding='same',activation='relu'))
    17 # 池化层2
    18 model.add(MaxPool2D(pool_size=(2, 2)))
    19 model.add(Dropout(0.25))
    20 # 三层卷积
    21 model.add(Conv2D(filters=64,kernel_size=(5, 5),padding='same',activation='relu'))
    22 # 四层卷积
    23 model.add(Conv2D(filters=128,kernel_size=(5, 5),padding='same',activation='relu'))
    24 # 池化层3
    25 model.add(MaxPool2D(pool_size=(2, 2)))
    26 model.add(Dropout(0.25))
    27 
    28 model.add(Flatten())  # 平坦层
    29 model.add(Dense(128, activation='relu'))  # 全连接层
    30 model.add(Dropout(0.25))
    31 model.add(Dense(10, activation='softmax')) # 激活函数
    32 
    33 model.summary()

    结果:

    4.模型训练

    • model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    • train_history = model.fit(x=X_train,y=y_train,validation_split=0.2, batch_size=300,epochs=10,verbose=2)

    源代码:

    1 #4.模型训练
    2 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    3 train_history = model.fit(x=x_train,y=y_train,validation_split=0.2, batch_size=300,epochs=10,verbose=2)

    结果:

    第一次训练

    第二次训练

    5.模型评价

    • model.evaluate()
    • 交叉表与交叉矩阵
    • pandas.crosstab
    • seaborn.heatmap

    源代码:

     1 #5.模型评价
     2 score =model.evaluate(x_test,y_test)
     3 print(score)
     4 
     5 #预测值
     6 y_pred=model.predict_classes(x_test)
     7 y_pred[:10]
     8 y_test[:10]
     9 
    10 #交叉表查看预测数据与原数据对比
    11 import pandas as pd
    12 import seaborn as sns
    13 y_test1=np.argmax(y_test,axis=1).reshape(-1)
    14 y_test1=np.array(y_test1)[0]#记得要将数据提取为一维的 不然后面的会报错
    15 y_test1.shape
    16 y_pred.shape
    17 
    18 a=pd.crosstab(np.array(y_test1),y_pred,rownames=['lables'],colnames=['predict'])
    19 #转换成dataframe
    20 df=pd.DataFrame(a)
    21 sns.heatmap(df,annot=True,cmap="YlGnBu",linewidths=0.2,linecolor='G')
    22 plt.show()

    结果:

     

  • 相关阅读:
    JS创建类的方法--简单易懂有实例
    CommonJS, AMD, CMD是什么及区别--简单易懂有实例
    JS回调函数--简单易懂有实例
    单链表应用(2)--使用快慢指针,如何判断是否有环,环在哪个节点
    单链表应用(1)--使用快慢指针,找链表中间值
    自定义线性结构-有序Map
    C++中final和override
    双向链表翻转
    检查“()”是否匹配并返回深度
    是否存在K
  • 原文地址:https://www.cnblogs.com/tao614/p/13092270.html
Copyright © 2011-2022 走看看