# ====================LeNet-5_main.py===============
# pytorch+torchvision+visdom
# -*- coding: utf-8 -*-
"""
Train and evaluate a modified LeNet-5 on MNIST, plotting the loss in visdom.

Created on Sun May 26 22:53:52 2019

@author: jiangshan
"""
# A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets.mnist import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import visdom
from collections import OrderedDict


class LeNet5(nn.Module):
    """
    Modified LeNet-5 classifier producing 10 log-probabilities per image.

    Input - 1x32x32
    C1 - 6@28x28 (5x5 kernel)
    relu
    S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling (max pooling)
    C3 - 16@10x10 (5x5 kernel, plain fully-connected convolution)
    relu
    S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling (max pooling)
    C5 - 120@1x1 (5x5 kernel)
    relu
    F6 - 84
    relu
    F7 - 10 (Output, followed by log-softmax)
    """

    def __init__(self):
        super(LeNet5, self).__init__()

        # Convolutional feature extractor: 1x32x32 -> 120x1x1.
        self.convnet = nn.Sequential(OrderedDict([
            ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
            ('relu1', nn.ReLU()),
            ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
            ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
            ('relu3', nn.ReLU()),
            ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
            ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
            ('relu5', nn.ReLU()),
        ]))

        # Classifier head: 120 -> 84 -> 10 log-probabilities.
        self.fc = nn.Sequential(OrderedDict([
            ('f6', nn.Linear(120, 84)),
            ('relu6', nn.ReLU()),
            ('f7', nn.Linear(84, 10)),
            ('sig7', nn.LogSoftmax(dim=-1)),
        ]))

    def forward(self, img):
        """Return the log-softmax class scores for a batch of images.

        Args:
            img: batch tensor whose first dimension is the batch size; the
                layer sizes above require 1x32x32 images.

        Returns:
            Tensor with one row of 10 log-probabilities per input image.
        """
        output = self.convnet(img)
        # Flatten everything except the batch dimension before the linear head.
        output = output.view(img.size(0), -1)
        output = self.fc(output)
        return output


viz = visdom.Visdom()

# MNIST digits are resized to 32x32 to match the network's expected input.
data_train = MNIST('./data/mnist',
                   download=True,
                   transform=transforms.Compose([
                       transforms.Resize((32, 32)),
                       transforms.ToTensor()]))
data_test = MNIST('./data/mnist',
                  train=False,
                  download=True,
                  transform=transforms.Compose([
                      transforms.Resize((32, 32)),
                      transforms.ToTensor()]))
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)

net = LeNet5()
# NOTE: CrossEntropyLoss applies log-softmax internally. The network already
# ends in LogSoftmax, and log-softmax of log-probabilities returns them
# unchanged, so the extra application is redundant but does not alter the loss.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=2e-3)

# Handle of the visdom window holding the loss plot; created on first draw.
cur_batch_win = None
cur_batch_win_opts = {
    'title': 'Epoch Loss Trace',
    'xlabel': 'Batch Number',
    'ylabel': 'Loss',
    'width': 1200,
    'height': 600,
}


def train(epoch):
    """Run one training epoch over ``data_train_loader``.

    Prints the loss every 10 batches and, when the visdom server is
    reachable, redraws the per-batch loss curve for the current epoch.

    Args:
        epoch: epoch number, used only in the progress message.
    """
    global cur_batch_win
    net.train()
    loss_list, batch_list = [], []
    for i, (images, labels) in enumerate(data_train_loader):
        optimizer.zero_grad()

        output = net(images)
        loss = criterion(output, labels)

        # Read the scalar once; it is reused for the plot and the log line.
        loss_value = loss.item()
        loss_list.append(loss_value)
        batch_list.append(i + 1)

        if i % 10 == 0:
            print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss_value))

        # Update Visualization
        if viz.check_connection():
            cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
                                     win=cur_batch_win, name='current_batch_loss',
                                     update=(None if cur_batch_win is None else 'replace'),
                                     opts=cur_batch_win_opts)

        loss.backward()
        optimizer.step()


def test():
    """Evaluate ``net`` on ``data_test_loader`` and print loss and accuracy.

    The reported loss is the mean per-sample loss over the whole test set.
    """
    net.eval()
    total_correct = 0
    total_loss = 0.0
    # Gradients are not needed for evaluation; skipping them avoids building
    # autograd graphs for every test batch.
    with torch.no_grad():
        for images, labels in data_test_loader:
            output = net(images)
            # criterion returns the batch mean, so weight it by the batch size
            # to obtain a per-sample sum that can be averaged over the dataset
            # (the last batch may be smaller than the others).
            total_loss += criterion(output, labels).item() * labels.size(0)
            pred = output.max(1)[1]
            total_correct += pred.eq(labels.view_as(pred)).sum().item()

    num_samples = len(data_test)
    avg_loss = total_loss / num_samples
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss, float(total_correct) / num_samples))


def train_and_test(epoch):
    """Train for one epoch, then evaluate on the test set."""
    train(epoch)
    test()


def main():
    """Run 15 train/evaluate epochs (numbered 1 through 15)."""
    for e in range(1, 16):
        train_and_test(e)


if __name__ == '__main__':
    main()
# Start the visdom server first to enable visualization:
#   python -m visdom.server
# Run the program:
#   python LeNet-5_main.py
# Open a browser to view the live graph.