zoukankan      html  css  js  c++  java
  • 15 手写数字识别-小数据集

    1.手写数字数据集

    from sklearn.datasets import load_digits

    digits = load_digits()

     

    2.图片数据预处理

    x:归一化MinMaxScaler()

    y:独热编码OneHotEncoder()或to_categorical

    将分类特征的每个元素转化为一个可以用来计算的值

    训练集测试集划分

    张量结构

     

     

    3.设计卷积神经网络结构

    绘制模型结构图,并说明设计依据。

     1 # 导入相关包
     2 # sequential设计层数
     3 from tensorflow.keras.models import Sequential
     4 from tensorflow.keras.layers import Dense,Dropout,Flatten,Conv2D,MaxPool2D
     5 
     6 #建立模型
     7 model=Sequential()
     8 
     9 ks=(3,3)
    10 ips=X_train.shape[1:]
    11 
    12 #一层卷积
    13 model.add(
    14     Conv2D(filters=16,          # 卷积核的个数
    15            kernel_size=ks,     # 卷积核大小
    16            padding='same',     # 保证卷积核大小,不够不算
    17            input_shape=ips,    
    18            activation='relu')) # activation: 激活函数,'relu','sigmoid'等
    19 
    20 #池化层1
    21 model.add(MaxPool2D(pool_size=(2,2)))
    22 model.add(Dropout(0.25))
    23 
    24 #二层卷积
    25 model.add(
    26     Conv2D(filters=32,kernel_size=ks,padding='same',activation='relu'))
    27 
    28 #池化层2
    29 model.add(MaxPool2D(pool_size=(2,2)))
    30 model.add(Dropout(0.25))
    31 
    32 #三层卷积
    33 model.add(
    34     Conv2D(filters=64,kernel_size=ks,padding='same',activation='relu'))
    35 
    36 #池化层3
    37 model.add(MaxPool2D(pool_size=(2,2)))
    38 model.add(Dropout(0.25))
    39 
    40 model.add(Flatten()) #平坦层
    41 model.add(Dense(128,activation='relu')) #dense全连接层
    42 model.add(Dropout(0.25))
    43 model.add(Dense(10,activation='softmax')) #激活函数  softmax分类
    44 
    45 model.summary()

    4.模型训练

    • model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    • train_history = model.fit(x=X_train,y=y_train,validation_split=0.2, batch_size=300,epochs=10,verbose=2)

     

    5.模型评价

    model.evaluate()

    交叉表与交叉矩阵

    pandas.crosstab

    从交叉表可见:少部分数字被预测错了

    seaborn.heatmap

  • 相关阅读:
    tornado源码分析-多进程
    create a cocos2d-x-3.0 project in Xcode
    记录自己的傻逼的错误:找不到或无法载入主类
    MVC5 Entity Framework学习之实现主要的CRUD功能
    Linux中实现多网卡绑定总结
    it码农之心灵鸡汤(一)
    【高级算法】遗传算法解决3SAT问题(C++实现)
    MySQL-分区表-1
    OpenSift源代码编译过程记录
    Android Studio 视图解析
  • 原文地址:https://www.cnblogs.com/linyanli/p/13091806.html
Copyright © 2011-2022 走看看