  • Deep Learning with PyTorch, Hands-On (12): Implementing ResNet-18 and Validating It on the CIFAR-10 Dataset

    ResNet explained with diagrams

    nn.Module in detail

    1. Building ResNet-18 in PyTorch

    1.1 The ResNet block submodule

    import torch
    from torch import nn
    from torch.nn import functional as F
    
    
    class ResBlk(nn.Module):
        """
        ResNet block submodule
        """
        def __init__(self, ch_in, ch_out, stride=1):
            # super(ResBlk, self).__init__()  # Python 2 style
            # Python 3 style
            super().__init__()
            
            self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, 
                                   stride=stride, padding=1)
            self.bn1 = nn.BatchNorm2d(ch_out)
            self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3,  # channel count stays ch_out
                                   stride=1, padding=1)
            self.bn2 = nn.BatchNorm2d(ch_out)
            
            self.extra = nn.Sequential()
            # if the input and output channel counts differ, project the shortcut with a 1x1 conv (applying the same stride) so it can be added to the main path
            if ch_out != ch_in:
                # maps x from [b, ch_in, h, w] => [b, ch_out, h, w]
                self.extra = nn.Sequential(
                    nn.Conv2d(ch_in, ch_out, kernel_size=1,  
                             stride=stride), 
                    nn.BatchNorm2d(ch_out)
                )
                
        def forward(self, x):
            
            out = F.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            
            out = self.extra(x) + out          # shortcut (identity or 1x1 projection) + residual branch
            out = F.relu(out)
            return out
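
    For the addition in forward() to work, the shortcut must end up with the same shape as the main path. The 1x1 conv in extra uses the same stride as the first 3x3 conv, and with kernel_size=3/padding=1 versus kernel_size=1/padding=0 the standard formula out = floor((in + 2*padding - kernel) / stride) + 1 gives the same spatial size for both branches. A minimal sketch checking the shapes:

    import torch
    from torch import nn

    x = torch.randn(2, 64, 32, 32)
    main_path = nn.Conv2d(64, 128, kernel_size=3, stride=4, padding=1)  # first conv of the block
    shortcut  = nn.Conv2d(64, 128, kernel_size=1, stride=4, padding=0)  # projection in self.extra
    # floor((32 + 2 - 3) / 4) + 1 = 8   and   floor((32 + 0 - 1) / 4) + 1 = 8
    print(main_path(x).shape, shortcut(x).shape)   # both torch.Size([2, 128, 8, 8])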
    

    1.2 The ResNet-18 main module

    class ResNet18(nn.Module):
        """
        ResNet-18 main module
        """
        def __init__(self):
            super(ResNet18, self).__init__()
            
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=3, padding=0),
                nn.BatchNorm2d(64)
            )
            # followed by 4 residual blocks
            self.blk1 = ResBlk(64, 128, stride=2)   # [b, 64, h, w] => [b, 128, h/2, w/2]
            self.blk2 = ResBlk(128, 256, stride=2)  # [b, 128, h, w] => [b, 256, h/2, w/2]
            self.blk3 = ResBlk(256, 512, stride=2)  # [b, 256, h, w] => [b, 512, h/2, w/2]
            self.blk4 = ResBlk(512, 512, stride=2)  # [b, 512, h, w] => [b, 512, h/2, w/2]
            
            self.outlayer = nn.Linear(512*1*1, 10)  # fully connected layer, 10 output classes
            
        def forward(self, x):
            x = F.relu(self.conv1(x))
            
            # [b, 64, h, w] => [b, 512, h', w']
            x = self.blk1(x)
            x = self.blk2(x)
            x = self.blk3(x)
            x = self.blk4(x)
            
            # whatever the spatial size of the feature map, adaptive pooling with output size (1, 1) reduces it to (1, 1)
            x = F.adaptive_avg_pool2d(x, [1, 1])   # [b, 512, h, w] => [b, 512, 1, 1]
            # flatten: the 4-D tensor must become 2-D before it can feed the fully connected layer
            x = x.view(x.size(0), -1)
            # fully connected layer
            x = self.outlayer(x)
            
            return x  
    

    Test:

    blk = ResBlk(64, 128, stride=4)
    tmp = torch.randn(2, 64, 32, 32)
    out = blk(tmp)
    print('block:', out.shape)                # block: torch.Size([2, 128, 8, 8])
    
    x = torch.randn(2, 3, 32, 32)
    model = ResNet18()
    out = model(x)
    print('resnet:', out.shape)               # resnet: torch.Size([2, 10])
    
    block: torch.Size([2, 128, 8, 8])
    resnet: torch.Size([2, 10])
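
    The forward pass works for any input resolution because F.adaptive_avg_pool2d(x, [1, 1]) averages whatever spatial grid remains down to 1x1, so the linear layer always receives 512 features. A small sketch illustrating this:

    import torch
    from torch.nn import functional as F

    for hw in (2, 4, 7):
        feat = torch.randn(2, 512, hw, hw)
        pooled = F.adaptive_avg_pool2d(feat, [1, 1])   # average over the full hw x hw grid
        flat = pooled.view(pooled.size(0), -1)         # same flatten as in ResNet18.forward
        print(hw, pooled.shape, flat.shape)            # always [2, 512, 1, 1] and [2, 512]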
    

    2. Training on the CIFAR-10 dataset

    • The dataset is CIFAR-10: 60,000 labeled color images of size 32*32, split into 10 classes with 6,000 images per class.

    • 50,000 images are used for training (5,000 per class); the remaining 10,000 are used for testing (1,000 per class). A quick check of these numbers is sketched right below.
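
    A minimal sketch confirming the dataset sizes (it downloads the data into the same cifar/ directory used by the training script below):

    from torchvision import datasets

    train_set = datasets.CIFAR10('cifar', train=True, download=True)
    test_set  = datasets.CIFAR10('cifar', train=False, download=True)
    print(len(train_set), len(test_set))   # 50000 10000
    print(train_set.classes)               # the 10 class names ('airplane', 'automobile', ...)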

    import  torch
    from    torch.utils.data import DataLoader
    from    torchvision import datasets,transforms
    from    torch import nn, optim
    
    from    resnet import ResNet18
    
    
    def main():
        batchsz = 128
    
        # training set
        cifar_train = datasets.CIFAR10('cifar', train=True, download=True, 
                                       transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225])
        ]))
        cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
    
    
        # test set
        cifar_test = datasets.CIFAR10('cifar', train=False, 
                                      transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225])
        ]))
        cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
    
    
        x, label = next(iter(cifar_train))
        # x: torch.Size([128, 3, 32, 32])  label: torch.Size([128])
        print('x:', x.shape, 'label:', label.shape)  
    
        # define the model: ResNet-18
        model = ResNet18()

        # define the loss function and the optimizer
        criteon = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        print(model)
    
        # train the network
        for epoch in range(1000):

            model.train()                               # training mode
            for batchidx, (x, label) in enumerate(cifar_train):
                # x: [b, 3, 32, 32]
                # label: [b]
    
                logits = model(x)                       # logits: [b, 10]
                loss = criteon(logits, label)           # scalar loss
    
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    
            print(epoch, 'loss:', loss.item())          # loss of the last batch in this epoch
    
    
            model.eval()                                # evaluation mode
            with torch.no_grad():
    
                total_correct = 0                       # number of correctly predicted samples
                total_num = 0
                for x, label in cifar_test:
                    # x: [b, 3, 32, 32]
                    # label: [b]
    
                    logits = model(x)                   # [b, 10]
                    pred = logits.argmax(dim=1)         # [b]
    
                    # [b] vs [b] => scalar tensor
                    correct = torch.eq(pred, label).float().sum().item()
                    total_correct += correct
                    total_num += x.size(0)
    
                acc = total_correct / total_num
                print(epoch, 'test acc:', acc)
    
    
    if __name__ == '__main__':
        main()
    
    • transforms.Normalize: standardizes the image channel by channel

      • output = (input - mean) / std

      • mean: per-channel means; std: per-channel standard deviations; inplace: whether to operate in place (see the sketch below)
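
    A minimal sketch verifying the formula on a single tensor (the mean/std values are the common ImageNet statistics used in the script above):

    import torch
    from torchvision import transforms

    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    img = torch.rand(3, 32, 32)   # fake image in [0, 1], as transforms.ToTensor() would produce
    normalized = transforms.Normalize(mean=mean, std=std)(img)
    manual = (img - torch.tensor(mean).view(3, 1, 1)) / torch.tensor(std).view(3, 1, 1)
    print(torch.allclose(normalized, manual))   # True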

    • torch.no_grad() is a context manager: the code it wraps does not track gradients.

    • torch.no_grad() can also be used as a decorator.

    • For example, place it above the network's evaluation function:

    @torch.no_grad()
    def eval():
    	...
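
    A small runnable sketch of both forms (the tiny model and the evaluate function are hypothetical, just for illustration):

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    x = torch.randn(3, 4)

    with torch.no_grad():                        # context-manager form
        print(model(x).requires_grad)            # False: no graph is built

    @torch.no_grad()                             # decorator form
    def evaluate(m, inp):
        return m(inp)

    print(evaluate(model, x).requires_grad)      # False
    print(model(x).requires_grad)                # True: outside no_grad, gradients are tracked again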
    

    Training is too slow, so only one epoch is run here.

    Files already downloaded and verified
    x: torch.Size([128, 3, 32, 32]) label: torch.Size([128])
    ResNet18(
      (conv1): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(3, 3))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (blk1): ResBlk(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (extra): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (blk2): ResBlk(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (extra): Sequential(
          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (blk3): ResBlk(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (extra): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (blk4): ResBlk(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (extra): Sequential()
      )
      (outlayer): Linear(in_features=512, out_features=10, bias=True)
    )
    0 loss: 1.0541729927062988
    0 test acc: 0.5873
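
    The run above is on the CPU, which is the main reason a single epoch takes so long. A minimal sketch of the changes needed to train on a GPU instead (assuming a CUDA device is available; the rest of main() stays the same, except that every batch must also be moved to the device):

    import torch
    from resnet import ResNet18

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet18().to(device)               # parameters now live on the GPU (if available)

    x = torch.randn(8, 3, 32, 32).to(device)    # in the loops: x, label = x.to(device), label.to(device)
    logits = model(x)
    print(logits.shape, logits.device)          # torch.Size([8, 10]) and cuda:0 (or cpu)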
    
  • Original post: https://www.cnblogs.com/douzujun/p/13361198.html