zoukankan      html  css  js  c++  java
  • MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:460356155@qq.com

    在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(二)中,采用全连接神经网络(784-300-10),分别用非深度学习框架和基于pytorch实现,训练结果相当。

    这里采用卷积神经网络(CNN)中著名的LeNet-5网络来训练,除了网络定义部分外,其他代码基本和MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(二)相同。

    网络定义代码:

     1 #定义网络模型
     2 class LeNet(nn.Module):
     3     def __init__(self):
     4         super(LeNet, self).__init__()
     5 
     6         self.cnn = nn.Sequential(
     7             #卷积层1,单通道输入,6个卷积核,核大小5*5
     8             #经过该层图像大小变为28-5+1,24*24
     9             #经2*2最大池化,图像变为12*12
    10             nn.Conv2d(1, 6, 5),
    11             nn.ReLU(),
    12             nn.MaxPool2d(2),
    13 
    14             #卷积层2,6通道,16个卷积核,核大小5*5
    15             #经过该层图像变为12-5+1,8*8
    16             # 经2*2最大池化,图像变为4*4
    17             nn.Conv2d(6, 16, 5),
    18             nn.ReLU(),
    19             nn.MaxPool2d(2)
    20         )
    21 
    22         self.fc = nn.Sequential(
    23             # 16个feature,每个feature4*4
    24             nn.Linear(16 * 4 * 4, 120),
    25             nn.ReLU(),
    26             nn.Linear(120, 84),
    27             nn.ReLU(),
    28             nn.Linear(84, 10)
    29         )
    30 
    31     def forward(self, x):
    32         x = self.cnn(x)
    33         x = x.view(x.size()[0], -1)
    34         x = self.fc(x)
    35         return x

    网络训练结果准确率约在99%,LeNet-5比前面的全连接神经网络高1.x%。运行结果如下:

    train data num: 60000 , test data num: 10000
    epoch:0 i:999 loss:0.11399480700492859
    epoch:0 i:1999 loss:0.1237913966178894
    epoch:0 i:2999 loss:0.12948277592658997
    EPOCH:0, ACC:97.5

    epoch:1 i:999 loss:0.006639003753662109
    epoch:1 i:1999 loss:0.0011253952980041504
    epoch:1 i:2999 loss:0.03325369954109192
    EPOCH:1, ACC:98.35

    epoch:2 i:999 loss:0.0021111369132995605
    epoch:2 i:1999 loss:0.2714851200580597
    epoch:2 i:2999 loss:0.0016380250453948975
    EPOCH:2, ACC:98.64

    epoch:3 i:999 loss:0.00033468008041381836
    epoch:3 i:1999 loss:0.05128034949302673
    epoch:3 i:2999 loss:0.1222798228263855
    EPOCH:3, ACC:98.65

    epoch:4 i:999 loss:0.0006810426712036133
    epoch:4 i:1999 loss:0.002728283405303955
    epoch:4 i:2999 loss:0.000545889139175415
    EPOCH:4, ACC:98.89

    epoch:5 i:999 loss:0.006086885929107666
    epoch:5 i:1999 loss:0.07402010262012482
    epoch:5 i:2999 loss:0.03638958930969238
    EPOCH:5, ACC:98.93

    epoch:6 i:999 loss:0.0002015829086303711
    epoch:6 i:1999 loss:0.0004933476448059082
    epoch:6 i:2999 loss:0.03196592628955841
    EPOCH:6, ACC:99.02

    epoch:7 i:999 loss:0.01734447479248047
    epoch:7 i:1999 loss:2.9087066650390625e-05
    epoch:7 i:2999 loss:0.018512487411499023
    EPOCH:7, ACC:98.73

    epoch:8 i:999 loss:4.70280647277832e-05
    epoch:8 i:1999 loss:0.008362054824829102
    epoch:8 i:2999 loss:2.9206275939941406e-06
    EPOCH:8, ACC:98.84

    epoch:9 i:999 loss:0.00012737512588500977
    epoch:9 i:1999 loss:0.00020432472229003906
    epoch:9 i:2999 loss:0.00022774934768676758
    EPOCH:9, ACC:99.1

    MINIST pytorch LeNet-5 Train: EPOCH:10, BATCH_SZ:16, LR:0.05
    train spend time:  0:01:05.897404

    损失函数值变化曲线为:

  • 相关阅读:
    input标签上传文件处理。
    Radio单选框元素操作。
    CompletableFuture方法
    传播学 2
    传播学 1
    0
    紅軍不怕遠征難
    ~~~~~~~~~
    什么是企业战略
    论述提供公共咨询服务的两种主要方式。
  • 原文地址:https://www.cnblogs.com/zhengbiqing/p/10408358.html
Copyright © 2011-2022 走看看