zoukankan      html  css  js  c++  java
  • LeNet-5 pytorch+torchvision+visdom

      # ====================LeNet-5_main.py===============
    # pytorch+torchvision+visdom
      1 # -*- coding: utf-8 -*-
      2 """
      3 Created on Sun May 26 22:53:52 2019
      4 
      5 @author: jiangshan
      6 """
      7 #A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
      8 import torch
      9 import torch.nn as nn
     10 import torch.optim as optim
     11 from torchvision.datasets.mnist import MNIST
     12 import torchvision.transforms as transforms
     13 from torch.utils.data import DataLoader
     14 import visdom
     15 from collections import OrderedDict
     16 
     17 class LeNet5(nn.Module):
     18     """
     19     Input - 1x32x32
     20     C1 - 6@28x28 (5x5 kernel)
     21     relu
     22     S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
     23     C3 - 16@10x10 (5x5 kernel, complicated shit)
     24     relu
     25     S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
     26     C5 - 120@1x1 (5x5 kernel)
     27     F6 - 84
     28     relu
     29     F7 - 10 (Output)
     30     """
     31     def __init__(self):
     32         super(LeNet5, self).__init__()
     33 
     34         self.convnet = nn.Sequential(OrderedDict([
     35             ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
     36             ('relu1', nn.ReLU()),
     37             ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
     38             ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
     39             ('relu3', nn.ReLU()),
     40             ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
     41             ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
     42             ('relu5', nn.ReLU())
     43         ]))
     44 
     45         self.fc = nn.Sequential(OrderedDict([
     46             ('f6', nn.Linear(120, 84)),
     47             ('relu6', nn.ReLU()),
     48             ('f7', nn.Linear(84, 10)),
     49             ('sig7', nn.LogSoftmax(dim=-1))
     50         ]))
     51 
     52     def forward(self, img):
     53         output = self.convnet(img)
     54         output = output.view(img.size(0), -1)
     55         output = self.fc(output)
     56         return output
     57 
     58 
     59 viz = visdom.Visdom()
     60 data_train = MNIST('./data/mnist',
     61                    download=True,
     62                    transform=transforms.Compose([
     63                        transforms.Resize((32, 32)),
     64                        transforms.ToTensor()]))
     65 data_test = MNIST('./data/mnist',
     66                   train=False,
     67                   download=True,
     68                   transform=transforms.Compose([
     69                       transforms.Resize((32, 32)),
     70                       transforms.ToTensor()]))
     71 data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
     72 data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)
     73 
     74 net = LeNet5()
     75 criterion = nn.CrossEntropyLoss()
     76 optimizer = optim.Adam(net.parameters(), lr=2e-3)
     77 
     78 cur_batch_win = None
     79 cur_batch_win_opts = {
     80     'title': 'Epoch Loss Trace',
     81     'xlabel': 'Batch Number',
     82     'ylabel': 'Loss',
     83     'width': 1200,
     84     'height': 600,
     85 }
     86 
     87 
     88 def train(epoch):
     89     global cur_batch_win
     90     net.train()
     91     loss_list, batch_list = [], []
     92     for i, (images, labels) in enumerate(data_train_loader):
     93         optimizer.zero_grad()
     94 
     95         output = net(images)
     96 
     97         loss = criterion(output, labels)
     98 
     99         loss_list.append(loss.detach().cpu().item())
    100         batch_list.append(i+1)
    101 
    102         if i % 10 == 0:
    103             print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item()))
    104 
    105         # Update Visualization
    106         if viz.check_connection():
    107             cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
    108                                      win=cur_batch_win, name='current_batch_loss',
    109                                      update=(None if cur_batch_win is None else 'replace'),
    110                                      opts=cur_batch_win_opts)
    111         loss.backward()
    112         optimizer.step()
    113 
    114 
    115 def test():
    116     net.eval()
    117     total_correct = 0
    118     avg_loss = 0.0
    119     for i, (images, labels) in enumerate(data_test_loader):
    120         output = net(images)
    121         avg_loss += criterion(output, labels).sum()
    122         pred = output.detach().max(1)[1]
    123         total_correct += pred.eq(labels.view_as(pred)).sum()
    124 
    125     avg_loss /= len(data_test)
    126     print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test)))
    127 
    128 
    129 def train_and_test(epoch):
    130     train(epoch)
    131     test()
    132 
    133 
    134 def main():
    135     for e in range(1, 16):
    136         train_and_test(e)
    137 
    138 
    139 if __name__ == '__main__':
    140     main()

    先开启visdom 进行可视化

    python -m visdom.server

    运行程序

    python LeNet-5_main.py

    打开浏览器查看live graph

    http://localhost:8097 

  • 相关阅读:
    什么是multipart/form-data请求
    jquery.fancybox
    laravel上传图片报错
    php框架安装
    ExtJS入门教程02,form也可以很优雅
    ExtJS入门教程01,Window如此简单,你怎能不会?
    Ext.Net学习笔记24:在ASP.NET MVC中使用Ext.Net
    Ext.Net学习笔记23:Ext.Net TabPanel用法详解
    Ext.Net学习笔记22:Ext.Net Tree 用法详解
    Ext.Net学习笔记21:Ext.Net FormPanel 字段验证(validation)
  • 原文地址:https://www.cnblogs.com/jeshy/p/10928315.html
Copyright © 2011-2022 走看看