zoukankan      html  css  js  c++  java
  • LeNet-5 pytorch+torchvision+visdom

      # ====================LeNet-5_main.py===============
    # pytorch+torchvision+visdom
      1 # -*- coding: utf-8 -*-
      2 """
      3 Created on Sun May 26 22:53:52 2019
      4 
      5 @author: jiangshan
      6 """
      7 #A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
      8 import torch
      9 import torch.nn as nn
     10 import torch.optim as optim
     11 from torchvision.datasets.mnist import MNIST
     12 import torchvision.transforms as transforms
     13 from torch.utils.data import DataLoader
     14 import visdom
     15 from collections import OrderedDict
     16 
     class LeNet5(nn.Module):
         """
         A modified LeNet-5 (LeCun et al., 1998) convolutional classifier.

         Layer layout, with shapes for the default ``in_channels=1``:

         Input - 1x32x32
         C1 - 6@28x28 (5x5 kernel)
         relu
         S2 - 6@14x14 (2x2 max-pool, stride 2)
         C3 - 16@10x10 (5x5 kernel; fully connected to all 6 input maps,
              unlike the sparse connection table of the original paper)
         relu
         S4 - 16@5x5 (2x2 max-pool, stride 2)
         C5 - 120@1x1 (5x5 kernel)
         relu
         F6 - 84
         relu
         F7 - num_classes (output), followed by LogSoftmax

         The spatial input size must be 32x32. Only then does C5 produce a 1x1
         map, so exactly 120 features reach F6.

         ``forward`` returns per-class log-probabilities of shape
         (batch, num_classes). Pair it with ``nn.NLLLoss``;
         ``nn.CrossEntropyLoss`` would apply a second, redundant log-softmax.

         Args:
             num_classes: width of the output layer F7 (default 10).
             in_channels: number of channels in the input image (default 1).
         """
         def __init__(self, num_classes=10, in_channels=1):
             super(LeNet5, self).__init__()

             self.convnet = nn.Sequential(OrderedDict([
                 ('c1', nn.Conv2d(in_channels, 6, kernel_size=(5, 5))),
                 ('relu1', nn.ReLU()),
                 ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
                 ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
                 ('relu3', nn.ReLU()),
                 ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
                 ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
                 ('relu5', nn.ReLU())
             ]))

             self.fc = nn.Sequential(OrderedDict([
                 ('f6', nn.Linear(120, 84)),
                 ('relu6', nn.ReLU()),
                 ('f7', nn.Linear(84, num_classes)),
                 # Emits log-probabilities, not a sigmoid, despite the 'sig7' name.
                 # The name is kept so existing module attribute paths stay valid.
                 ('sig7', nn.LogSoftmax(dim=-1))
             ]))

         def forward(self, img):
             """Map a (batch, in_channels, 32, 32) tensor to (batch, num_classes) log-probabilities."""
             output = self.convnet(img)
             # (batch, 120, 1, 1) -> (batch, 120). Unlike view(), flatten also
             # copes with non-contiguous tensors.
             output = torch.flatten(output, 1)
             return self.fc(output)
     57 
     58 
     # NOTE(review): everything below runs at import time (Visdom client,
     # dataset download, model construction). DataLoader workers started with
     # the 'spawn' method (e.g. on Windows) may re-import this module and
     # repeat this work. Consider moving it under main() if that matters.
     viz = visdom.Visdom()

     # Both splits share one preprocessing pipeline: resize to the 32x32 input
     # that LeNet5 requires, then convert the PIL image to a tensor.
     mnist_transform = transforms.Compose([
         transforms.Resize((32, 32)),
         transforms.ToTensor()])
     data_train = MNIST('./data/mnist',
                        download=True,
                        transform=mnist_transform)
     data_test = MNIST('./data/mnist',
                       train=False,
                       download=True,
                       transform=mnist_transform)
     data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
     data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)

     net = LeNet5()
     # LeNet5 already ends in LogSoftmax, so NLLLoss is the matching criterion.
     # CrossEntropyLoss would apply log-softmax a second time.
     # The default reduction='mean' is kept.
     criterion = nn.NLLLoss()
     optimizer = optim.Adam(net.parameters(), lr=2e-3)

     # Visdom window handle for the per-batch loss plot; set by train() on first plot.
     cur_batch_win = None
     cur_batch_win_opts = {
         'title': 'Epoch Loss Trace',
         'xlabel': 'Batch Number',
         'ylabel': 'Loss',
         'width': 1200,
         'height': 600,
     }
     86 
     87 
     def train(epoch):
         """
         Run one training epoch of ``net`` over ``data_train_loader``.

         Prints the loss every 10 batches. When the Visdom server is reachable
         at the start of the epoch, it also plots this epoch's per-batch loss
         trace into the shared ``cur_batch_win`` window. From the second plot
         call onward the trace is sent with update='replace'.

         Args:
             epoch: epoch number, used only in the progress message.
         """
         global cur_batch_win
         net.train()
         # Probe the Visdom server once per epoch rather than once per batch.
         # check_connection() contacts the server, so a per-batch probe adds a
         # request to every training step.
         viz_connected = viz.check_connection()
         loss_list, batch_list = [], []
         for i, (images, labels) in enumerate(data_train_loader):
             optimizer.zero_grad()

             output = net(images)
             loss = criterion(output, labels)

             # Extract the scalar once and reuse it for logging and plotting.
             loss_value = loss.detach().cpu().item()
             loss_list.append(loss_value)
             batch_list.append(i + 1)

             if i % 10 == 0:
                 print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss_value))

             # Update Visualization
             if viz_connected:
                 cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
                                          win=cur_batch_win, name='current_batch_loss',
                                          update=(None if cur_batch_win is None else 'replace'),
                                          opts=cur_batch_win_opts)
             loss.backward()
             optimizer.step()
    113 
    114 
    def test():
        """
        Evaluate ``net`` on the test split and print the mean loss and accuracy.

        The loss is averaged per sample. ``criterion`` is constructed in this
        file with the default reduction='mean', so it returns a per-batch mean.
        Each batch value is therefore weighted by its batch size before dividing
        by the number of test samples. The last batch may be smaller than the others.
        """
        net.eval()
        total_correct = 0
        total_loss = 0.0
        # Inference only: no_grad avoids building, and then accumulating,
        # an autograd graph for every test batch.
        with torch.no_grad():
            for images, labels in data_test_loader:
                output = net(images)
                total_loss += criterion(output, labels).item() * labels.size(0)
                pred = output.max(1)[1]
                total_correct += pred.eq(labels.view_as(pred)).sum().item()

        avg_loss = total_loss / len(data_test)
        print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, float(total_correct) / len(data_test)))
    127 
    128 
    def train_and_test(epoch):
        """Run one full epoch: a training pass, then an evaluation on the test split."""
        train(epoch)  # fit on the training split for this epoch
        test()        # report test-set loss and accuracy afterwards
    132 
    133 
    def main():
        """Entry point: run 15 train/test epochs, numbered from 1."""
        num_epochs = 15
        for epoch in range(1, num_epochs + 1):
            train_and_test(epoch)


    if __name__ == '__main__':
        main()

    先启动 visdom 服务，以便进行可视化：

    python -m visdom.server

    运行程序

    python LeNet-5_main.py

    在浏览器中打开以下地址，查看实时更新的训练损失曲线（live graph）：

    http://localhost:8097 

  • 相关阅读:
    MAC OpenGL 环境搭建
    C++中调用OC代码
    XCode快捷键使用
    【iOS】史上最全的iOS持续集成教程 (下)
    【iOS】史上最全的iOS持续集成教程 (上)
    pod 指令无效
    iOS面试题总结(持续更新)
    数据结构与算法思维导图
    Swift编码规范总结
    同步异步执行问题
  • 原文地址:https://www.cnblogs.com/jeshy/p/10928315.html
Copyright © 2011-2022 走看看