  • A PyTorch implementation of ResNet

    1. Original paper

    Deep Residual Learning for Image Recognition

    2. Abstract

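    Deeper neural networks are harder to train. We present a residual learning framework to ease the training of networks that are substantially deeper than those used previously. We explicitly reformulate the layers as learning residual functions with reference to the layer inputs, instead of learning unreferenced functions. We provide comprehensive empirical evidence that these residual networks are easier to optimize and can gain accuracy from considerably increased depth. On the ImageNet dataset we evaluate residual nets up to 152 layers deep, 8x deeper than VGG nets but still with lower complexity. An ensemble of these residual nets achieves a 3.57% error rate on ImageNet. This result won first place in the ILSVRC 2015 classification task. We also present analysis on CIFAR-10 with 100 and 1000 layers.

    The depth of representations is of central importance for many visual recognition tasks. Solely due to our extremely deep representations, we obtain a 28% relative improvement on the COCO object detection dataset. Deep residual nets are the foundation of our submissions to the ILSVRC & COCO 2015 competitions, where we also won first place in ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation.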
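
    The residual reformulation described in the abstract can be written compactly. The notation below follows the building-block equations of the paper: $\mathcal{F}$ stands for the stacked layers inside one block, and $W_s$ is a projection that is only needed when the input and output shapes differ.

```latex
\begin{aligned}
\mathbf{y} &= \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x}
  && \text{(identity shortcut, shapes match)} \\
\mathbf{y} &= \mathcal{F}(\mathbf{x}, \{W_i\}) + W_s\,\mathbf{x}
  && \text{(projection shortcut, shapes differ)}
\end{aligned}
```

    In the implementation in section 4, the line `out += identity` is this addition, and the optional `downsample` module (a 1x1 convolution followed by a normalization layer) plays the role of $W_s$. The final ReLU is applied after the addition.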

    3. Network architecture
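
    As a rough sketch of the layout that the code in section 4 builds, the helper below traces the spatial resolution of a square input through the stem and the four stages, using the standard convolution output-size formula. The names `conv_out` and `trace_resnet_shapes` and the 224-pixel input are illustrative assumptions. The kernel sizes, strides, and paddings are read off the `ResNet` class with its default settings (no dilation).

```python
def conv_out(size, kernel, stride, padding, dilation=1):
    """Output length along one spatial axis for a conv or pooling window."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def trace_resnet_shapes(input_size=224):
    """Return (stage name, spatial size) pairs for the stem and the four stages."""
    shapes = []

    # Stem: 7x7 convolution with stride 2 and padding 3.
    size = conv_out(input_size, kernel=7, stride=2, padding=3)
    shapes.append(("conv1", size))

    # 3x3 max pooling with stride 2 and padding 1.
    size = conv_out(size, kernel=3, stride=2, padding=1)
    shapes.append(("maxpool", size))

    # Only the first block of a stage applies the stage stride, through a
    # 3x3 convolution with padding 1; the remaining blocks keep the size.
    for name, stride in [("layer1", 1), ("layer2", 2), ("layer3", 2), ("layer4", 2)]:
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        shapes.append((name, size))

    return shapes


if __name__ == "__main__":
    for name, size in trace_resnet_shapes(224):
        print(f"{name}: {size}x{size}")
```

    With the default strides, the resolution is halved five times in total, so the adaptive average pool at the end sees a feature map 32 times smaller than the input along each axis.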

    4. PyTorch implementation

    import torch.nn as nn
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


    __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
               'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']


    # Pretrained ImageNet weights; there are no entries for the ResNeXt variants.
    model_urls = {
        'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    }


    def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
        """3x3 convolution with padding"""
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=dilation, groups=groups, bias=False, dilation=dilation)


    def conv1x1(in_planes, out_planes, stride=1):
        """1x1 convolution"""
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


    class BasicBlock(nn.Module):
        """Two 3x3 convolutions plus a shortcut; used by ResNet-18 and ResNet-34."""
        expansion = 1

        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                     base_width=64, dilation=1, norm_layer=None):
            super(BasicBlock, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            if groups != 1 or base_width != 64:
                raise ValueError('BasicBlock only supports groups=1 and base_width=64')
            if dilation > 1:
                raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
            # Both self.conv1 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = norm_layer(planes)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = conv3x3(planes, planes)
            self.bn2 = norm_layer(planes)
            self.downsample = downsample
            self.stride = stride

        def forward(self, x):
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)

            # Project the shortcut when the shape of the residual branch differs.
            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity
            out = self.relu(out)

            return out


    class Bottleneck(nn.Module):
        """1x1 -> 3x3 -> 1x1 convolutions plus a shortcut; used by ResNet-50 and deeper."""
        expansion = 4

        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                     base_width=64, dilation=1, norm_layer=None):
            super(Bottleneck, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            # Inner width of the block; larger than planes for the ResNeXt variants.
            width = int(planes * (base_width / 64.)) * groups
            # Both self.conv2 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv1x1(inplanes, width)
            self.bn1 = norm_layer(width)
            self.conv2 = conv3x3(width, width, stride, groups, dilation)
            self.bn2 = norm_layer(width)
            self.conv3 = conv1x1(width, planes * self.expansion)
            self.bn3 = norm_layer(planes * self.expansion)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride

        def forward(self, x):
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)

            # Project the shortcut when the shape of the residual branch differs.
            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity
            out = self.relu(out)

            return out


    class ResNet(nn.Module):

        def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                     groups=1, width_per_group=64, replace_stride_with_dilation=None,
                     norm_layer=None):
            super(ResNet, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            self._norm_layer = norm_layer

            self.inplanes = 64
            self.dilation = 1
            if replace_stride_with_dilation is None:
                # each element in the tuple indicates if we should replace
                # the 2x2 stride with a dilated convolution instead
                replace_stride_with_dilation = [False, False, False]
            if len(replace_stride_with_dilation) != 3:
                raise ValueError("replace_stride_with_dilation should be None "
                                 "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
            self.groups = groups
            self.base_width = width_per_group
            # Stem: 7x7 convolution and 3x3 max pooling, each with stride 2.
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                   bias=False)
            self.bn1 = norm_layer(self.inplanes)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            # Four stages; layers[i] is the number of blocks in stage i.
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                           dilate=replace_stride_with_dilation[0])
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilate=replace_stride_with_dilation[1])
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           dilate=replace_stride_with_dilation[2])
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)

            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

            # Zero-initialize the last BN in each residual branch,
            # so that the residual branch starts with zeros, and each residual block behaves like an identity.
            # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
            if zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        nn.init.constant_(m.bn3.weight, 0)
                    elif isinstance(m, BasicBlock):
                        nn.init.constant_(m.bn2.weight, 0)

        def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
            norm_layer = self._norm_layer
            downsample = None
            previous_dilation = self.dilation
            if dilate:
                # Trade the stride for dilation so the resolution is preserved.
                self.dilation *= stride
                stride = 1
            if stride != 1 or self.inplanes != planes * block.expansion:
                # The shortcut needs a projection when the resolution or channel count changes.
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, planes * block.expansion, stride),
                    norm_layer(planes * block.expansion),
                )

            layers = []
            # Only the first block of a stage changes the stride and channel count.
            layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                                self.base_width, previous_dilation, norm_layer))
            self.inplanes = planes * block.expansion
            for _ in range(1, blocks):
                layers.append(block(self.inplanes, planes, groups=self.groups,
                                    base_width=self.base_width, dilation=self.dilation,
                                    norm_layer=norm_layer))

            return nn.Sequential(*layers)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)

            x = self.avgpool(x)
            # Flatten (N, C, 1, 1) to (N, C) before the classifier.
            x = x.reshape(x.size(0), -1)
            x = self.fc(x)

            return x


    def _resnet(arch, block, layers, pretrained, progress, **kwargs):
        # block is the block class; layers is the list of block counts per stage.
        model = ResNet(block, layers, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=progress)
            model.load_state_dict(state_dict)
        return model


    def resnet18(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-18 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                       **kwargs)


    def resnet34(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-34 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                       **kwargs)


    def resnet50(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-50 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                       **kwargs)


    def resnet101(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-101 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                       **kwargs)


    def resnet152(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-152 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                       **kwargs)


    def resnext50_32x4d(**kwargs):
        """Constructs a ResNeXt-50 32x4d model (randomly initialized)."""
        kwargs['groups'] = 32
        kwargs['width_per_group'] = 4
        return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                       pretrained=False, progress=True, **kwargs)


    def resnext101_32x8d(**kwargs):
        """Constructs a ResNeXt-101 32x8d model (randomly initialized)."""
        kwargs['groups'] = 32
        kwargs['width_per_group'] = 8
        return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                       pretrained=False, progress=True, **kwargs)
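
    The layer lists passed to the constructors above determine the nominal depth in each model's name, and `groups` together with `width_per_group` determine the inner width of a `Bottleneck`. The sketch below reproduces both calculations in plain Python. `nominal_depth` and `bottleneck_width` are hypothetical helper names, and the depth convention (count the stem convolution, the convolutions on the residual branches, and the final fully connected layer, but not the 1x1 downsample projections) is the usual one rather than something the code enforces.

```python
# Block type and blocks-per-stage, copied from the constructor functions.
CONFIGS = {
    "resnet18": ("BasicBlock", [2, 2, 2, 2]),
    "resnet34": ("BasicBlock", [3, 4, 6, 3]),
    "resnet50": ("Bottleneck", [3, 4, 6, 3]),
    "resnet101": ("Bottleneck", [3, 4, 23, 3]),
    "resnet152": ("Bottleneck", [3, 8, 36, 3]),
}

# Convolutions on the residual branch of each block type.
CONVS_PER_BLOCK = {"BasicBlock": 2, "Bottleneck": 3}


def nominal_depth(block_name, layers):
    """Stem conv + residual-branch convs + final fully connected layer."""
    return 1 + CONVS_PER_BLOCK[block_name] * sum(layers) + 1


def bottleneck_width(planes, groups=1, base_width=64):
    """Same expression as `width` in Bottleneck.__init__."""
    return int(planes * (base_width / 64.)) * groups


if __name__ == "__main__":
    for name, (block_name, layers) in CONFIGS.items():
        print(name, nominal_depth(block_name, layers))
    # First-stage inner widths: plain ResNet vs. the two ResNeXt settings.
    print(bottleneck_width(64))
    print(bottleneck_width(64, groups=32, base_width=4))
    print(bottleneck_width(64, groups=32, base_width=8))
```

    With the default `groups=1` and `base_width=64` the inner width equals `planes`, which is why the same `Bottleneck` class serves both the ResNet and ResNeXt constructors.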

    References

    https://github.com/pytorch/vision/tree/master/torchvision/models

  • Original post: https://www.cnblogs.com/ys99/p/10872262.html