zoukankan      html  css  js  c++  java
  • LeNet-5 pytorch+torchvision+visdom

      # ====================LeNet-5_main.py===============
    # pytorch+torchvision+visdom
      1 # -*- coding: utf-8 -*-
      2 """
      3 Created on Sun May 26 22:53:52 2019
      4 
      5 @author: jiangshan
      6 """
      7 #A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
      8 import torch
      9 import torch.nn as nn
     10 import torch.optim as optim
     11 from torchvision.datasets.mnist import MNIST
     12 import torchvision.transforms as transforms
     13 from torch.utils.data import DataLoader
     14 import visdom
     15 from collections import OrderedDict
     16 
     17 class LeNet5(nn.Module):
     18     """
     19     Input - 1x32x32
     20     C1 - 6@28x28 (5x5 kernel)
     21     relu
     22     S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
     23     C3 - 16@10x10 (5x5 kernel, complicated shit)
     24     relu
     25     S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
     26     C5 - 120@1x1 (5x5 kernel)
     27     F6 - 84
     28     relu
     29     F7 - 10 (Output)
     30     """
     31     def __init__(self):
     32         super(LeNet5, self).__init__()
     33 
     34         self.convnet = nn.Sequential(OrderedDict([
     35             ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
     36             ('relu1', nn.ReLU()),
     37             ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
     38             ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
     39             ('relu3', nn.ReLU()),
     40             ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
     41             ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
     42             ('relu5', nn.ReLU())
     43         ]))
     44 
     45         self.fc = nn.Sequential(OrderedDict([
     46             ('f6', nn.Linear(120, 84)),
     47             ('relu6', nn.ReLU()),
     48             ('f7', nn.Linear(84, 10)),
     49             ('sig7', nn.LogSoftmax(dim=-1))
     50         ]))
     51 
     52     def forward(self, img):
     53         output = self.convnet(img)
     54         output = output.view(img.size(0), -1)
     55         output = self.fc(output)
     56         return output
     57 
     58 
     59 viz = visdom.Visdom()
     60 data_train = MNIST('./data/mnist',
     61                    download=True,
     62                    transform=transforms.Compose([
     63                        transforms.Resize((32, 32)),
     64                        transforms.ToTensor()]))
     65 data_test = MNIST('./data/mnist',
     66                   train=False,
     67                   download=True,
     68                   transform=transforms.Compose([
     69                       transforms.Resize((32, 32)),
     70                       transforms.ToTensor()]))
     71 data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
     72 data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)
     73 
     74 net = LeNet5()
     75 criterion = nn.CrossEntropyLoss()
     76 optimizer = optim.Adam(net.parameters(), lr=2e-3)
     77 
     78 cur_batch_win = None
     79 cur_batch_win_opts = {
     80     'title': 'Epoch Loss Trace',
     81     'xlabel': 'Batch Number',
     82     'ylabel': 'Loss',
     83     'width': 1200,
     84     'height': 600,
     85 }
     86 
     87 
     88 def train(epoch):
     89     global cur_batch_win
     90     net.train()
     91     loss_list, batch_list = [], []
     92     for i, (images, labels) in enumerate(data_train_loader):
     93         optimizer.zero_grad()
     94 
     95         output = net(images)
     96 
     97         loss = criterion(output, labels)
     98 
     99         loss_list.append(loss.detach().cpu().item())
    100         batch_list.append(i+1)
    101 
    102         if i % 10 == 0:
    103             print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item()))
    104 
    105         # Update Visualization
    106         if viz.check_connection():
    107             cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
    108                                      win=cur_batch_win, name='current_batch_loss',
    109                                      update=(None if cur_batch_win is None else 'replace'),
    110                                      opts=cur_batch_win_opts)
    111         loss.backward()
    112         optimizer.step()
    113 
    114 
    115 def test():
    116     net.eval()
    117     total_correct = 0
    118     avg_loss = 0.0
    119     for i, (images, labels) in enumerate(data_test_loader):
    120         output = net(images)
    121         avg_loss += criterion(output, labels).sum()
    122         pred = output.detach().max(1)[1]
    123         total_correct += pred.eq(labels.view_as(pred)).sum()
    124 
    125     avg_loss /= len(data_test)
    126     print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test)))
    127 
    128 
    129 def train_and_test(epoch):
    130     train(epoch)
    131     test()
    132 
    133 
    134 def main():
    135     for e in range(1, 16):
    136         train_and_test(e)
    137 
    138 
    139 if __name__ == '__main__':
    140     main()

    先启动 visdom 服务器，用于实时可视化训练过程：

    python -m visdom.server

    然后运行训练脚本：

    python LeNet-5_main.py

    在浏览器中打开以下地址，查看实时更新的损失曲线（live graph）：

    http://localhost:8097 

  • 相关阅读:
    Android MVP模式简单易懂的介绍方式 (一)
    Android studio如何和VS的region一样折叠代码
    Android的设置界面及Preference使用
    unity shader 剔除指定的颜色
    如何有效述职
    通俗易懂,什么是.NET?什么是.NET Framework?什么是.NET Core?
    通过输入命令行参数来控制程序
    unity 用代码控制动画的播放的进度
    Unity 连接WebSocket(ws://)服务器
    随笔
  • 原文地址:https://www.cnblogs.com/jeshy/p/10928315.html
Copyright © 2011-2022 走看看