zoukankan      html  css  js  c++  java
  • 15 手写数字识别-小数据集

    补交作业:

    4.K均值算法 

    12.朴素贝叶斯-垃圾邮件分类 

    第一次是看错时间忘记交作业了,第二次是家里没电,晚上12点多才有电,没交到作业

    1.手写数字数据集

    • from sklearn.datasets import load_digits
    • digits = load_digits()
    1 from sklearn.datasets import load_digits
    2 import numpy as np
    3 
    4 digits = load_digits()
    5 x_data = digits.data.astype(np.float32)
    6 y_data = digits.target.astype(np.float32).reshape(-1, 1)

    2.图片数据预处理

    • x:归一化MinMaxScaler()
    • y:独热编码OneHotEncoder()或to_categorical
    • 训练集测试集划分
    • 张量结构
    复制代码
     1 from sklearn.model_selection import train_test_split
     2 from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
     3 
     4 scaler = MinMaxScaler()
     5 x_data = scaler.fit_transform(x_data)
     6 print(x_data)
     7 x = x_data.reshape(-1, 8, 8, 1)  # 转换为图片格式
     8 y = OneHotEncoder().fit_transform(y_data).todense()
     9 # 训练集测试集划分
    10 x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0, stratify=y)
    11 print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)
    复制代码

    3.设计卷积神经网络结构

    • 绘制模型结构图,并说明设计依据。
    复制代码
     1 from tensorflow.keras.models import Sequential
     2 from tensorflow.keras.layers import Dense,Dropout,Flatten,Conv2D,MaxPool2D
     3 import matplotlib.pyplot as plt
     4 # 3.设计卷积神经网络结构
     5 # 建立模型
     6 model = Sequential()
     7 ks = [3, 3]  # 卷积核
     8 # 一层卷积
     9 model.add(Conv2D(filters=16, kernel_size=ks, padding='same', input_shape=x_train.shape[1:], activation='relu'))
    10 # 池化层
    11 model.add(MaxPool2D(pool_size=(2, 2)))
    12 model.add(Dropout(0.25))
    13 # 二层卷积
    14 model.add(Conv2D(filters=32, kernel_size=ks, padding='same', activation='relu'))
    15 # 池化层
    16 model.add(MaxPool2D(pool_size=(2, 2)))
    17 model.add(Dropout(0.25))
    18 # 三层卷积
    19 model.add(Conv2D(filters=64, kernel_size=ks, padding='same', activation='relu'))
    20 # 四层卷积
    21 model.add(Conv2D(filters=128, kernel_size=ks, padding='same', activation='relu'))
    22 # 池化层
    23 model.add(MaxPool2D(pool_size=(2, 2)))
    24 model.add(Dropout(0.25))
    25 # 平坦层
    26 model.add(Flatten())
    27 # 全连接层
    28 model.add(Dense(128, activation='relu'))
    29 model.add(Dropout(0.25))
    30 # 激活函数
    31 model.add(Dense(10, activation='softmax'))
    32 model.summary()
    复制代码

    4.模型训练

    复制代码
     1 #绘制模型结构图
     2 def show_train_history(train_history, train, validation):
     3     plt.plot(train_history.history[train])
     4     plt.plot(train_history.history[validation])
     5     plt.title('Train History')
     6     plt.ylabel('train')
     7     plt.xlabel('epoch')
     8     plt.legend(['train', 'validation'], loc='upper left')
     9     plt.show()
    10 model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    11 train_history = model.fit(x=x_train, y=y_train, validation_split=0.2, batch_size=300, epochs=10, verbose=2)
    12 # 准确率
    13 show_train_history(train_history, 'accuracy', 'val_accuracy')
    14 # 损失率
    15 show_train_history(train_history, 'loss', 'val_loss')
    复制代码
    准确率:

    损失率:

    5.模型评价

    • model.evaluate()
    • 交叉表与交叉矩阵
    • pandas.crosstab
    • seaborn.heatmap
    复制代码
     1 import pandas as pd
     2 import seaborn as sns
     3 # 5、模型评价
     4 #模型评估
     5 score = model.evaluate(x_test, y_test)[1]
     6 print('模型准确率=',score)
     7 # 预测值
     8 y_pre = model.predict_classes(x_test)
     9 y_pre[:10]
    10 y_test1 = np.argmax(y_test, axis=1).reshape(-1)
    11 y_true = np.array(y_test1)[0]
    12 y_true.shape
    13 pd.crosstab(y_true, y_pre, rownames=['true'], colnames=['predict'])
    14 # 交叉矩阵
    15 y_test1 = y_test1.tolist()[0]
    16 a = pd.crosstab(np.array(y_test1), y_pre, rownames=['Lables'], colnames=['predict'])
    17 df = pd.DataFrame(a)
    18 print(df)
    19 sns.heatmap(df, annot=True, cmap="Reds", linewidths=0.2, linecolor='G')
    20 plt.show()
    复制代码

  • 相关阅读:
    用python40行代码编写的计算器
    用Python语言设计GUI界面
    win7下安装Linux实现双系统全攻略
    Dreamweaver_CS6安装与破解,手把手教程
    windows Server 2008各版本有何区别?
    如何查看路由器中的pppoe拨号密码?
    xp远程桌面连接最大用户数怎么设置?
    网站的盈利模式
    linux 下安装mysql-5.7.16
    BroadcastReceiver接收电量变化的广播-------在代码中动态创建接受者
  • 原文地址:https://www.cnblogs.com/a188182/p/13114751.html
Copyright © 2011-2022 走看看