zoukankan      html  css  js  c++  java
  • MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(二)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:460356155@qq.com

    在前一篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)中,直接用python创建全连接神经网络模型进行深度学习训练,这样可以对神经网络有较为深刻的认识。

    但是在实际应用中,一般都是采用各种深度学习框架来开展人工智能项目,以下就采用pytorch来实现前一篇文章中的全连接神经网络(784-300-10)。

      1 # -*- coding:utf-8 -*-
      2 
      3 u"""pytorch LineNet神经网络训练学习MINIST"""
      4 
      5 __author__ = 'zhengbiqing 460356155@qq.com'
      6 
      7 
      8 import torch as t
      9 import torchvision as tv
     10 import torch.nn as nn
     11 import torch.nn.functional as F
     12 import torchvision.transforms as transforms
     13 from torch.autograd import Variable
     14 import matplotlib.pyplot as plt
     15 import datetime
     16 
     17 
     18 #是否训练网络
     19 TRAIN = True
     20 
     21 #是否保存网络
     22 SAVE_PARA = False
     23 
     24 #学习率和训练次数
     25 LR = 0.05
     26 EPOCH = 10
     27 
     28 #训练每批次的样本数
     29 BATCH_SZ = 16
     30 
     31 #样本读取线程数
     32 WORKERS = 4
     33 
     34 #网络参赛保存文件名
     35 PARAS_FN = 'minist_linenet_params.pkl'
     36 
     37 #minist数据存放位置
     38 ROOT = '/home/zbq/pytorch/minist'
     39 
     40 
     41 #定义网络模型
     42 class LineNet(nn.Module):
     43     def __init__(self):
     44         super(LineNet, self).__init__()
     45 
     46         self.fc = nn.Sequential(
     47             nn.Linear(28*28, 300),
     48             nn.ReLU(),
     49             nn.Linear(300, 10)
     50         )
     51 
     52     def forward(self, x):
     53         #x是2维tensor,转换为1维向量
     54         x = x.view(x.size()[0], -1)
     55         x = self.fc(x)
     56         return x
     57 
     58 
     59 '''
     60 训练并测试网络
     61 net:网络模型
     62 train_data_load:训练数据集
     63 test_data_load:测试数据集
     64 epochs:训练迭代次数
     65 save:是否保存训练结果
     66 '''
     67 def net_train(net, train_data_load, test_data_load, epochs, save):
     68     start_time = datetime.datetime.now()
     69 
     70     loss_list = []
     71 
     72     for epoch in range(epochs):
     73         for i, data in enumerate(train_data_load, 0):
     74             img, label = data
     75             img, label = Variable(img), Variable(label)
     76             img, label = img.cuda(), label.cuda()
     77 
     78             optimizer.zero_grad()
     79 
     80             pre = net(img)
     81             loss = loss_func(pre, label)
     82             loss.backward()
     83 
     84             optimizer.step()
     85 
     86             #显示损失函数值的变化
     87             loss_data = loss.data.item()
     88             if i % 1000 == 999:
     89                 print('epoch:{epoch} i:{i} loss:{loss}'.format(epoch=epoch, i=i, loss=loss_data))
     90 
     91             if i % 100 == 99:
     92                 loss_list.append(loss_data)
     93 
     94         # 每个epoch结束后用测试集检查识别准确度
     95         net_test(epoch, test_data_load)
     96 
     97     print('MINIST pytorch LineNet Train: EPOCH:{epochs}, BATCH_SZ:{batch_sz}, LR:{lr}'.format(epochs=epochs, batch_sz=BATCH_SZ, lr=LR))
     98     print('train spend time: ', datetime.datetime.now() - start_time)
     99 
    100     if save:
    101         t.save(net.state_dict(), PARAS_FN)
    102 
    103     #显示目标函数值的变化曲线
    104     plt.plot(loss_list)
    105     plt.show()
    106 
    107 
    108 '''
    109 用测试集检查准确率
    110 '''
    111 def net_test(epoch, test_data_load):
    112     ok = 0
    113 
    114     for i, data in enumerate(test_data_load):
    115         img, label = data
    116         img, label = Variable(img), Variable(label)
    117         img, label = img.cuda(), label.cuda()
    118 
    119         outs = net(img)
    120         _, pre = t.max(outs.data, 1)
    121         ok += (pre == label).sum()
    122 
    123     acc = ok.item() * 100 / (len(test_data_load) * BATCH_SZ)
    124 
    125     print('EPOCH:{epoch}, ACC:{acc}
    '.format(epoch=epoch, acc=acc))
    126 
    127 
    128 #图像数值转换,ToTensor源码注释
    129 """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
    130     Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    131     [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    132     """
    133 #归一化,把[0.0, 1.0]变换为[-1,1], ([0, 1] - 0.5) / 0.5 = [-1, 1]
    134 transform = tv.transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    135 
    136 #定义数据集
    137 train_data = tv.datasets.MNIST(root=ROOT, train=True, download=True, transform=transform)
    138 test_data = tv.datasets.MNIST(root=ROOT, train=False, download=False, transform=transform)
    139 
    140 train_load = t.utils.data.DataLoader(train_data, batch_size=BATCH_SZ, shuffle=True, num_workers=WORKERS)
    141 test_load = t.utils.data.DataLoader(test_data, batch_size=BATCH_SZ, shuffle=False, num_workers=WORKERS)
    142 
    143 print('train data num:', len(train_data), ', test data num:', len(test_data))
    144 
    145 
    146 net = LineNet()
    147 net.cuda()
    148 
    149 loss_func = nn.CrossEntropyLoss()
    150 optimizer = t.optim.SGD(net.parameters(), lr=LR)
    151 
    152 if TRAIN:
    153     net_train(net, train_load, test_load, EPOCH, SAVE_PARA)
    154 else:
    155     net.load_state_dict(t.load(PARAS_FN))
    156     net_test(0, test_load)

    网络训练结果准确率基本在97%~98%,和前一篇MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(一)相同网络结构的全连接神经网络相当,但是因为这里采用GPU运算,训练时间降低到1/8。

    此外,借助pytorch,代码更简单。

    运行结果如下:

    train data num: 60000 , test data num: 10000
    epoch:0 i:999 loss:0.3457891643047333
    epoch:0 i:1999 loss:0.09639787673950195
    epoch:0 i:2999 loss:0.27898865938186646
    EPOCH:0, ACC:94.84

    epoch:1 i:999 loss:0.33745211362838745
    epoch:1 i:1999 loss:0.11106520891189575
    epoch:1 i:2999 loss:0.21725007891654968
    EPOCH:1, ACC:96.42

    epoch:2 i:999 loss:0.3825737535953522
    epoch:2 i:1999 loss:0.02866300940513611
    epoch:2 i:2999 loss:0.11832481622695923
    EPOCH:2, ACC:96.77

    epoch:3 i:999 loss:0.11886310577392578
    epoch:3 i:1999 loss:0.012149035930633545
    epoch:3 i:2999 loss:0.030409961938858032
    EPOCH:3, ACC:97.2

    epoch:4 i:999 loss:0.008915185928344727
    epoch:4 i:1999 loss:0.008089780807495117
    epoch:4 i:2999 loss:0.0005310177803039551
    EPOCH:4, ACC:97.6

    epoch:5 i:999 loss:0.02993696928024292
    epoch:5 i:1999 loss:0.01784616708755493
    epoch:5 i:2999 loss:0.10544028878211975
    EPOCH:5, ACC:97.6

    epoch:6 i:999 loss:0.008486062288284302
    epoch:6 i:1999 loss:0.0334945023059845
    epoch:6 i:2999 loss:0.00291365385055542
    EPOCH:6, ACC:97.37

    epoch:7 i:999 loss:0.0062919557094573975
    epoch:7 i:1999 loss:0.0003241896629333496
    epoch:7 i:2999 loss:0.0006818175315856934
    EPOCH:7, ACC:97.23

    epoch:8 i:999 loss:0.0007421970367431641
    epoch:8 i:1999 loss:0.005641639232635498
    epoch:8 i:2999 loss:0.005949795246124268
    EPOCH:8, ACC:97.7

    epoch:9 i:999 loss:0.024028539657592773
    epoch:9 i:1999 loss:0.005388796329498291
    epoch:9 i:2999 loss:0.0029097795486450195
    EPOCH:9, ACC:97.39

    MINIST pytorch LineNet Train: EPOCH:10, BATCH_SZ:16, LR:0.05
    train spend time:  0:00:43.183836

    损失函数值变化曲线为:

  • 相关阅读:
    【NOIP模拟】寻找
    【NOIP模拟】行走
    【UVA11795】 Mega Man's Mission
    【UVA11825】Hackers' Crackdown
    【UVA1252】Twenty Questions
    BZOJ1718: [Usaco2006 Jan] Redundant Paths 分离的路径
    BZOJ1151: [CTSC2007]动物园zoo
    BZOJ1123: [POI2008]BLO
    BZOJ1040: [ZJOI2008]骑士
    POJ3417:Network
  • 原文地址:https://www.cnblogs.com/zhengbiqing/p/10408225.html
Copyright © 2011-2022 走看看