zoukankan      html  css  js  c++  java
  • Pytorch-实现ResNet-18并在Cifar-10数据集上进行验证

    1.Pytorch上搭建ResNet-18

     1 import  torch
     2 from    torch import  nn
     3 from    torch.nn import functional as F
     4 
     5 
     class ResBlk(nn.Module):
         """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

         Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3(stride 1) -> BN.
         The shortcut is the identity when the block keeps both the channel
         count and the spatial size; otherwise it is a 1x1 conv (same stride)
         followed by BN, so that both branches have the same shape when added.

         Args:
             ch_in: number of channels of the input tensor.
             ch_out: number of channels produced by the block.
             stride: stride of the first 3x3 conv and of the projection shortcut.
         """

         def __init__(self, ch_in, ch_out, stride=1):
             super(ResBlk, self).__init__()

             self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
             self.bn1 = nn.BatchNorm2d(ch_out)
             self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
             self.bn2 = nn.BatchNorm2d(ch_out)

             # The shortcut must be projected whenever the main path changes the
             # tensor shape: a different channel count OR a stride other than 1
             # (which shrinks H and W). Testing only the channels would leave an
             # identity shortcut for e.g. ResBlk(512, 512, stride=2); the
             # addition in forward() would then raise a shape error or, for a
             # 1x1 main-path output, silently broadcast.
             if stride != 1 or ch_out != ch_in:
                 # [b, ch_in, h, w] => [b, ch_out, h', w'], matching the main path
                 self.extra = nn.Sequential(
                     nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                     nn.BatchNorm2d(ch_out),
                 )
             else:
                 # Empty Sequential acts as the identity mapping.
                 self.extra = nn.Sequential()

         def forward(self, x):
             """Apply the block to ``x`` ([b, ch_in, h, w]) and return [b, ch_out, h', w']."""
             out = F.relu(self.bn1(self.conv1(x)))
             out = self.bn2(self.conv2(out))

             # Residual connection: add the (possibly projected) input, then activate.
             out = self.extra(x) + out
             return F.relu(out)
    36 
    37 
    class ResNet18(nn.Module):
        """Image classifier: a conv stem, four ResBlk stages and a linear head.

        The stem maps 3 input channels to 64; the four residual blocks widen the
        features to 128, 256, 512 and 512 channels, each with stride 2. The
        result is globally average-pooled to one value per channel and fed to a
        fully connected layer.

        Args:
            num_classes: size of the output logit vector (defaults to 10, the
                value that was previously hard-coded).
        """

        def __init__(self, num_classes=10):
            super(ResNet18, self).__init__()

            # Stem: [b, 3, h, w] => [b, 64, h', w'] (3x3 conv, stride 3, no padding)
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
                nn.BatchNorm2d(64),
            )

            # Four residual blocks; every one uses stride 2, so the spatial size
            # shrinks at each stage while the channel count grows.
            self.blk1 = ResBlk(64, 128, stride=2)     # 64  -> 128 channels
            self.blk2 = ResBlk(128, 256, stride=2)    # 128 -> 256 channels
            self.blk3 = ResBlk(256, 512, stride=2)    # 256 -> 512 channels
            self.blk4 = ResBlk(512, 512, stride=2)    # 512 -> 512 channels

            # Classification head on the pooled 512-dimensional feature vector.
            self.outlayer = nn.Linear(512 * 1 * 1, num_classes)

        def forward(self, x):
            """Return class logits of shape [b, num_classes] for images ``x`` of shape [b, 3, h, w]."""
            x = F.relu(self.conv1(x))

            # [b, 64, h', w'] => [b, 512, h'', w'']
            x = self.blk1(x)
            x = self.blk2(x)
            x = self.blk3(x)
            x = self.blk4(x)

            # Global average pooling: [b, 512, h'', w''] => [b, 512, 1, 1]
            x = F.adaptive_avg_pool2d(x, [1, 1])
            # Flatten to [b, 512] for the linear layer.
            x = x.view(x.size(0), -1)
            return self.outlayer(x)

    举个栗子测试一下:

     if __name__ == '__main__':
         # Smoke test 1: one residual block on a random batch.
         # Per the original listing this prints: block: torch.Size([2, 128, 8, 8])
         block = ResBlk(64, 128, stride=4)
         block_out = block(torch.randn(2, 64, 32, 32))
         print('block:', block_out.shape)

         # Smoke test 2: the full network on a random batch of 32x32 RGB images.
         # Per the original listing this prints: resnet: torch.Size([2, 10])
         images = torch.randn(2, 3, 32, 32)
         net = ResNet18()
         print('resnet:', net(images).shape)

    2.训练Cifar-10数据集

    所选数据集为CIFAR-10。该数据集共有60000张带标签的彩色图像,图像尺寸为32×32,分为10个类,每类6000张。其中50000张用于训练(每类5000张),另外10000张用于测试(每类1000张)。

     1 import  torch
     2 from    torch.utils.data import DataLoader
     3 from    torchvision import datasets,transforms 
     4 from    torch import nn, optim
     5 
     6 from    resnet import ResNet18
     7 
     8 
     def main():
         """Train ResNet18 on CIFAR-10 and report test accuracy after every epoch.

         Builds the train/test data loaders, trains with Adam and cross-entropy
         loss for 1000 epochs, and after each epoch prints the loss of the last
         training batch and the accuracy over the whole test set.
         """
         batchsz = 128

         # Run on the GPU when one is available, otherwise on the CPU.
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

         # Both splits use the same preprocessing pipeline.
         # NOTE(review): these mean/std values are the commonly used ImageNet
         # channel statistics, not statistics computed on CIFAR-10 — confirm
         # this is intended.
         transform = transforms.Compose([
             transforms.Resize((32, 32)),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         ])

         # Training set. download=True fetches the data into ./cifar on first
         # use; without it a fresh checkout fails because the files are missing.
         cifar_train = datasets.CIFAR10('cifar', True, transform=transform, download=True)
         cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)

         # Test set. Shuffling is unnecessary for evaluation: accuracy does not
         # depend on batch order.
         cifar_test = datasets.CIFAR10('cifar', False, transform=transform, download=True)
         cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=False)

         # Peek at one batch to check the shapes. Use the built-in next(): the
         # Python-2-style .next() method is not available on current DataLoader
         # iterators.
         x, label = next(iter(cifar_train))
         print('x:', x.shape, 'label:', label.shape)         # x: [128, 3, 32, 32]  label: [128]

         # Model
         model = ResNet18().to(device)

         # Loss function and optimizer
         criteon = nn.CrossEntropyLoss()
         optimizer = optim.Adam(model.parameters(), lr=1e-3)
         print(model)

         for epoch in range(1000):

             model.train()                                   # training mode (BN uses batch statistics)
             for x, label in cifar_train:
                 # x: [b, 3, 32, 32], label: [b]
                 x, label = x.to(device), label.to(device)

                 logits = model(x)                           # [b, 10]
                 loss = criteon(logits, label)               # scalar

                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()

             # Loss of the last training batch of this epoch.
             print(epoch, 'loss:', loss.item())

             model.eval()                                    # evaluation mode (BN uses running statistics)
             with torch.no_grad():
                 total_correct = 0                           # number of correct predictions
                 total_num = 0                               # number of evaluated samples
                 for x, label in cifar_test:
                     x, label = x.to(device), label.to(device)

                     logits = model(x)                       # [b, 10]
                     pred = logits.argmax(dim=1)             # [b]

                     # [b] vs [b] => number of matches in this batch
                     total_correct += torch.eq(pred, label).float().sum().item()
                     total_num += x.size(0)

                 acc = total_correct / total_num
                 print(epoch, 'test acc:', acc)


     if __name__ == '__main__':
         main()

    完整训练需要1000个epoch,耗时太久,这里暂且只给出前5个epoch的输出。

    0 loss: 1.0912220478057861
    0 test acc: 0.5583
    1 loss: 0.8604468107223511
    1 test acc: 0.6592
    2 loss: 0.6625195145606995
    2 test acc: 0.6827
    3 loss: 0.7064175009727478
    3 test acc: 0.6904
    4 loss: 0.5687283277511597
    4 test acc: 0.7059

  • 相关阅读:
    chrome jsonView插件安装
    Android之父Andy Rubin:被乔布斯羡慕嫉妒的天才
    一张图看懂苹果MacBook所有屏幕分辨率
    Mac如何让调整窗口大小更简单
    OS X快捷键小技巧
    magent编译安装及常见错误
    【STL】算法 — partial_sort
    Lucene 4.4 依据Int类型字段删除索引
    简易实现 TextView单行文本水平触摸滑动效果
    cocos2d js 怎样动态载入外部图片
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13332640.html
Copyright © 2011-2022 走看看