zoukankan      html  css  js  c++  java
  • 龙良曲pytorch学习笔记_LeNet5

    LeNet5网络和CIFAR10数据集

    main函数--dataloader--train--test

     1 import torch
     2 from torch.utils.data import DataLoader
     3 from torchvision import datasets
     4 from torchvision import transforms
     5 from torch import nn,optim
     6 from lenet5 import LeNet5
     7 
     8 def main():
     9     batch_size = 32
    10     cifar_train = datasets.CIFAR10('cifar',train = True,transform = transforms.Compose([
    11         transforms.Resize((32,32)),
    12         transforms.ToTensor()
    13     ]),download = True)
    14     
    15     # 可以同时加载多张图片
    16     cifar_train = DataLoader(cifar_train,batch_size = batch_size,shuffle = True)
    17     
    18     cifar_test = datasets.CIFAR10('cifar',train = False,transform = transforms.Compose([
    19         transforms.Resize((32,32)),
    20         transforms.ToTensor()
    21     ]),download = True)
    22     
    23     # 可以同时加载多张图片
    24     cifar_test = DataLoader(cifar_test,batch_size = batch_size,shuffle = True)
    25 
    26     # 数据加载成功后可以检验shape
    27     x,label = iter(cifar_train).next()
    28     print('x:',x.shape,'label:',label.shape)
    29 
    30     device = torch.device('cuda')
    31     model = LeNet5().to(device)
    32     criteon = nn.CrossEntropyLoss().to(device)
    33     optimizer = optim.Adam(model.parameters(),lr=1e-3)
    34     
    35     print(model)
    36     
    37     for epoch in range(1000):
    38         
    39         model.train()
    40         for batchidx,(x,label) in enumerate(cifar_train):
    41             # x: [b,3,32,32], label: [b]
    42             x,label = x.to(device),label.to(device)
    43             
    44             logits = model(x)
    45             # logits:[b,10]
    46             # label:[b]
    47             loss = criteon(logits,label)
    48             
    49             # backprop
    50             optimizer.zero_grad()
    51             loss.backwark()
    52             optimizer.step()
    53             
    54         #
    55         print(epoch,loss.item())
    56         
    57         model.eval()
    58         # 不需要做梯度相关计算
    59         with torch.nn_grad():
    60             # test
    61             total_correct = 0
    62             total_num = 0
    63             for x,label in cifar_test:
    64                 x,label = x.to(device),label.to(device)
    65                 # logits:[b,10]
    66                 logits = model(x)
    67                 pred = logits.argmax(dim=1)
    68                 # 获取一个batch的在累加
    69                 total_correct = += torch.eq(pred,label).float().sum().item()
    70                 # x.size(0)就是batch_size
    71                 total_num += x.size(0)
    72                 
    73             acc = total_correct / total_num
    74             print(epoch,acc)
    75             
    76 if __name__ == '__main__'
    77     main()

    LeNet网络--tmp测试

     1 import torch
     2 from torch import nn
     3 from torch.nn import functional as F
     4 
     5 class LeNet5(nn.Module):
     6     """
     7     for cifar10 dataset.
     8     """
     9     def __init__(self):
    10         super(LeNet5,self).__init__()
    11         
    12         self.conv_unit = nn.Sequential(
    13             # x:[b,3,32,32] --> [b,6,]
    14             # input_channel,output_channel,kernel_size,stride,padding
    15             nn.Conv2d(3,6,kernel_size = 5,stride = 1,padding = 0),
    16             nn.AvgPool2d(kernel_size = 2,stride = 2,padding = 0),
    17             # 
    18             nn.Conv2d(6,16,kernel_size = 5,stride = 1,padding = 0),
    19             nn.AvgPool2d(kernel_size = 2,stride = 2,padding = 0),
    20         )
    21         # Flatten
    22         # fc_unit
    23         self.fc_unit = nn.Sequential(
    24             # 由下面的测试得出来的
    25             nn.Linear(16*5*5,120),
    26             # 全连接层会出现梯度离散现象,加一个relu
    27             nn.ReLU(),
    28             nn.Linear(120,84),
    29             nn.ReLU(),
    30             nn.Linear(84,10),
    31         )
    32         '''
    33         tmp = torch.randn(2,3,32,32)
    34         out = self.conv_unit(tmp)
    35         # 测试一下输出的维度,用于全连接层
    36         # [2,16,5,5]
    37         print('conv_out:',out.shape)
    38         '''
    39         
    40         # use Cross Entropy Loss
    41         # 放到类外,不用引入y参数
    42         # self.criteon = nn.CrossEntropyLoss()
    43         
    44         # 从左往右走的,backward会自动根据这个走
    45     def forward(self,x):
    46         # 取得x的shape,然后0号为batch_size
    47         batch_size = x.size(0)
    48         # [b,3,32,32] --> [b,16,5,5]
    49         x = self.conv_unit(x)
    50         # [b,16,5,5] --> [b,16*5*5]
    51         x = x.view(batch_size,16*5*5)
    52         # [b,16*5*5] --> [b,10]
    53         logits = self.fc_unit(x)
    54         return logits
    55         # [b,10] crossEntropy会包含,不用写
    56         # pred = F.softmax(logits,dim = 1)
    57         # loss = self.criteon(logits,y)
    58         
    59         
    60 def main():
    61 
    62     net = LeNet5()
    63     tmp = torch.randn(2,3,32,32)
    64     out = net(tmp)
    65     print('lenet_out:',out.shape)
    66     
    67     
    68 if __name__ == '__main__'
    69     main()
  • 相关阅读:
    王健林:在中国远离政府太假了 期望王思聪稳重
    科目二很难考吗?经验全在这里!
    HTTP 的长连接和短连接
    JS中实现字符串和数组的相互转化
    Maven介绍,包括作用、核心概念、用法、常用命令、扩展及配置
    kafka数据可靠性深度解读
    深入浅出JMS(二)--ActiveMQ简单介绍以及安装
    ActiveMQ入门实例
    activemq的几种基本通信方式总结
    mysql按年度、季度、月度、周、日SQL统计查询
  • 原文地址:https://www.cnblogs.com/fxw-learning/p/12317504.html
Copyright © 2011-2022 走看看