zoukankan      html  css  js  c++  java
  • 学习笔记1：用 PyTorch 实现 LeNet-5

    参考资料：https://cuijiahua.com/blog/2018/01/dl_3.html

    代码实现（第一段为模型文件 lenet5.py，第二段为训练脚本，它通过 from lenet5 import Lenet5 导入模型）：

     1 import torch
     2 from torch import nn
     3 from torch.nn import functional as F
     4 class Lenet5(nn.Module):
     5     """
     6     for cifar10 dataset
     7     """
     8     def __init__(self):
     9         super(Lenet5, self).__init__()
    10         self.conv_unit=nn.Sequential(
    11             #卷积层 x:[b,3,32,32] =>[b,6,30,30]
    12             nn.Conv2d(3,6,kernel_size=3,stride=1,padding=0),
    13             #池化层 [b,6,30,30]=>[b,6,15,15]
    14             nn.MaxPool2d(kernel_size=2,stride=2,padding=0),
    15             #卷积层 [b,6,15,15] =>[b,16,13,13]
    16             nn.Conv2d(6,16,kernel_size=3,stride=1,padding=0),
    17             #池化层 [b,16,13,13]=>[b,16,6,6]
    18             nn.MaxPool2d(kernel_size=2,stride=2,padding=0)
    19         )
    20         #flatten
    21         #全连接层
    22         self.fc_unit=nn.Sequential(
    23             nn.Linear(16*6*6,120),
    24             nn.ReLU(inplace=True),
    25             nn.Linear(120,84),
    26             nn.ReLU(inplace=True),
    27             nn.Linear(84,10)
    28         )
    29 
    30         #use Cross Entropy loss
    31         #self.criteon=nn.CrossEntropyLoss()
    32 
    33     def forward(self,x):
    34         """
    35         x:[b,3,32,32]
    36         :param x:
    37         :return:
    38         """
    39         batchsz=x.size(0)
    40         #[b,3,32,32]=>[b,16,6,6]
    41         x=self.conv_unit(x)
    42         #[b,16,6,6] => [b,16*6*6]
    43         x=x.view(batchsz,16*6*6)
    44         #[b,16*6*6]=>[b,10]
    45         logits=self.fc_unit(x)
    46 
    47         return logits
    48         #
    49     
    50 #测试函数
    51 def main():
    52     net=Lenet5()
    53     tmp = torch.randn(2, 3, 32, 32)
    54     #通过测试,确定conv_unit输入维度
    55     out = net(tmp)
    56     print('conv out:', out.shape)
    57 if __name__=='__main__':
    58     main()
     1 import  torch
     2 from    torch.utils.data import DataLoader
     3 from    torchvision import datasets
     4 from    torchvision import transforms
     5 from    torch import nn, optim
     6 
     7 from    lenet5 import Lenet5
     8 
     9 def main():
    10     #batchsize大小
    11     batchsz = 32
    12     #读取一张图片并进行数据增强
    13     cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
    14         transforms.Resize((32, 32)),
    15         transforms.ToTensor()
    16     ]), download=True)
    17     #读取bathsize大小的一批图像
    18     cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    19 
    20     cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
    21         transforms.Resize((32, 32)),
    22         transforms.ToTensor()
    23     ]), download=True)
    24     cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
    25 
    26     #测试图像读取是否正确
    27     #x, label = iter(cifar_train).next()
    28     #print('x:', x.shape, 'label:', label.shape)
    29 
    30     #选用gpu设备
    31     device = torch.device('cuda:0,1')
    32     model = Lenet5().to(device)
    33     #model = ResNet18().to(device)
    34     
    35     #定义交叉熵评估损失值
    36     criteon = nn.CrossEntropyLoss().to(device)
    37     #定义优化器
    38     optimizer = optim.Adam(model.parameters(), lr=1e-3)
    39     #打印模型信息
    40     print(model)
    41 
    42     #train
    43     for epoch in range(1000):
    44         #model.train()
    45         for batchidx, (x, label) in enumerate(cifar_train):
    46             # [b, 3, 32, 32]
    47             # [b]
    48             x, label = x.to(device), label.to(device)
    49             
    50             #预测值
    51             logits = model(x)
    52             # logits: [b, 10]
    53             # label:  [b]
    54             # loss: tensor scalar
    55             loss = criteon(logits, label)
    56 
    57             # 反向传播
    58             optimizer.zero_grad()
    59             loss.backward()
    60             optimizer.step()
    61 
    62         print(epoch, 'loss:', loss.item())
    63         #test
    64         model.eval()
    65         with torch.no_grad():
    66             # test
    67             total_correct = 0
    68             total_num = 0
    69             for x, label in cifar_test:
    70                 # [b, 3, 32, 32]
    71                 # [b]
    72                 x, label = x.to(device), label.to(device)
    73 
    74                 # [b, 10]
    75                 logits = model(x)
    76                 # [b] 选取所有类别得分的最大值作为预测类
    77                 pred = logits.argmax(dim=1)
    78                 # [b] vs [b] => scalar tensor
    79                 #计算每个batch正确的预测数量
    80                 correct = torch.eq(pred, label).float().sum().item()
    81                 total_correct += correct
    82                 total_num += x.size(0)
    83             acc = total_correct / total_num
    84             print(epoch, 'acc:', acc)
    85 
    86 if __name__ == '__main__':
    87     main()

    结果：（原文此处的运行结果未能保留）

     

  • 相关阅读:
    "Java:comp/env/"讲解与JNDI
    table的td去边框
    jsp获取所有参数
    spring-mvc设置首页
    jdbc数据库连接方式
    文件上传
    SMBMS
    过滤器和监听器
    解决Maven的JDK版本问题
    MVC
  • 原文地址:https://www.cnblogs.com/sclu/p/11947643.html
Copyright © 2011-2022 走看看