  • PyTorch Implementation of ResNet

    1. Original paper

    Deep Residual Learning for Image Recognition

    2. Abstract

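    Deeper neural networks are more difficult to train. We present a residual learning framework to ease the training of networks that are substantially deeper than those used previously. We explicitly reformulate the layers as learning residual functions with reference to the layer inputs, instead of learning unreferenced functions. We provide comprehensive empirical evidence showing that these residual networks are easier to optimize and can gain accuracy from considerably increased depth. On the ImageNet dataset we evaluate residual nets with a depth of up to 152 layers, 8 times deeper than VGG nets but still having lower complexity. An ensemble of these residual nets achieves 3.57% error on the ImageNet test set. This result won 1st place on the ILSVRC 2015 classification task. We also present analysis on CIFAR-10 with 100 and 1000 layers.

    The depth of representations is of central importance for many visual recognition tasks. Solely due to our extremely deep representations, we obtain a 28% relative improvement on the COCO object detection dataset. Deep residual nets are the foundation of our submissions to the ILSVRC & COCO 2015 competitions, where we also won 1st place on the tasks of ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation.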
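
The residual formulation described in the abstract can be sketched in a few lines of plain Python. This is only an illustration under simplifying assumptions, not the paper's code: `residual_branch`, `relu` and `residual_block` are hypothetical names, and an element-wise scale stands in for the stacked convolution and normalization layers of a real block.

```python
def residual_branch(x, weights):
    """Hypothetical residual function F(x): an element-wise scale that stands
    in for the stacked conv/BN layers of a real block."""
    return [w * xi for w, xi in zip(weights, x)]


def relu(values):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def residual_block(x, weights):
    """Compute relu(F(x) + x): the layers learn F, the shortcut carries x."""
    fx = residual_branch(x, weights)
    return relu([f + xi for f, xi in zip(fx, x)])


x = [0.5, 1.0, 2.0]

# With the residual branch driven to zero, the block reduces to the identity
# mapping (for non-negative inputs, because ReLU follows the addition).
identity_out = residual_block(x, [0.0, 0.0, 0.0])

# With non-zero weights, the block outputs the input plus a learned correction.
shifted_out = residual_block(x, [0.5, -0.5, 0.25])

print(identity_out)
print(shifted_out)
```

The point of the sketch is that an identity mapping only requires pushing the residual branch towards zero, which is one intuition for why such blocks are easier to optimize.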

    3. Network structure
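
As a rough guide to the structure, the sketch below traces channel counts and spatial sizes through the stages defined by the `ResNet` class in section 4. It is a plain-Python approximation, not part of the original post: `conv_out_size` and `trace_resnet_shapes` are hypothetical helper names, the 224x224 input size is an assumption, and `replace_stride_with_dilation` is assumed to be left at its default (no dilation).

```python
def conv_out_size(size, kernel, stride, padding):
    """Output length of a convolution or pooling window along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_resnet_shapes(input_size=224, expansion=1):
    """Return (stage name, channels, spatial size) for each stage.

    expansion is 1 for BasicBlock networks and 4 for Bottleneck networks.
    """
    stages = []
    # Stem: 7x7 convolution with stride 2 and padding 3, then 3x3 max pooling.
    size = conv_out_size(input_size, kernel=7, stride=2, padding=3)
    stages.append(("conv1", 64, size))
    size = conv_out_size(size, kernel=3, stride=2, padding=1)
    stages.append(("maxpool", 64, size))
    # Four residual stages; only the first block of a stage applies the stride,
    # through a 3x3 convolution with padding 1.
    stage_specs = [(64, 1), (128, 2), (256, 2), (512, 2)]
    for index, (planes, stride) in enumerate(stage_specs, start=1):
        size = conv_out_size(size, kernel=3, stride=stride, padding=1)
        stages.append(("layer%d" % index, planes * expansion, size))
    # Adaptive average pooling always reduces the feature map to 1x1.
    stages.append(("avgpool", 512 * expansion, 1))
    return stages


for name, channels, size in trace_resnet_shapes(expansion=4):
    print(name, channels, size)
```

The final channel count (512 for `BasicBlock`, 2048 for `Bottleneck`) matches the input width of the `fc` layer, `512 * block.expansion`.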

    4. PyTorch implementation

    import torch.nn as nn
    # In newer PyTorch releases the same helper is also available as
    # torch.hub.load_state_dict_from_url.
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


    __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
               'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']


    # Download locations of the ImageNet-pretrained weights.
    model_urls = {
        'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    }


    def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
        """3x3 convolution with padding"""
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=dilation, groups=groups, bias=False, dilation=dilation)


    def conv1x1(in_planes, out_planes, stride=1):
        """1x1 convolution"""
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


    class BasicBlock(nn.Module):
        """Two 3x3 convolutions plus a shortcut; used by ResNet-18 and ResNet-34."""
        expansion = 1

        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                     base_width=64, dilation=1, norm_layer=None):
            super(BasicBlock, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            if groups != 1 or base_width != 64:
                raise ValueError('BasicBlock only supports groups=1 and base_width=64')
            if dilation > 1:
                raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
            # Both self.conv1 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = norm_layer(planes)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = conv3x3(planes, planes)
            self.bn2 = norm_layer(planes)
            self.downsample = downsample
            self.stride = stride

        def forward(self, x):
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)

            # Project the shortcut when the shape changes (stride or channel count).
            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity
            out = self.relu(out)

            return out


    class Bottleneck(nn.Module):
        """1x1 -> 3x3 -> 1x1 convolutions plus a shortcut; used by ResNet-50 and deeper."""
        expansion = 4

        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                     base_width=64, dilation=1, norm_layer=None):
            super(Bottleneck, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            # Width of the inner 3x3 convolution; wider for ResNeXt (groups > 1).
            width = int(planes * (base_width / 64.)) * groups
            # Both self.conv2 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv1x1(inplanes, width)
            self.bn1 = norm_layer(width)
            self.conv2 = conv3x3(width, width, stride, groups, dilation)
            self.bn2 = norm_layer(width)
            self.conv3 = conv1x1(width, planes * self.expansion)
            self.bn3 = norm_layer(planes * self.expansion)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride

        def forward(self, x):
            identity = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)

            # Project the shortcut when the shape changes (stride or channel count).
            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity
            out = self.relu(out)

            return out


    class ResNet(nn.Module):

        def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                     groups=1, width_per_group=64, replace_stride_with_dilation=None,
                     norm_layer=None):
            super(ResNet, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            self._norm_layer = norm_layer

            self.inplanes = 64
            self.dilation = 1
            if replace_stride_with_dilation is None:
                # each element in the tuple indicates if we should replace
                # the 2x2 stride with a dilated convolution instead
                replace_stride_with_dilation = [False, False, False]
            if len(replace_stride_with_dilation) != 3:
                raise ValueError("replace_stride_with_dilation should be None "
                                 "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
            self.groups = groups
            self.base_width = width_per_group
            # Stem: 7x7 convolution followed by 3x3 max pooling.
            self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                   bias=False)
            self.bn1 = norm_layer(self.inplanes)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            # Four residual stages; stages 2-4 halve the resolution unless dilated.
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                           dilate=replace_stride_with_dilation[0])
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                           dilate=replace_stride_with_dilation[1])
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                           dilate=replace_stride_with_dilation[2])
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)

            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

            # Zero-initialize the last BN in each residual branch,
            # so that the residual branch starts with zeros, and each residual block behaves like an identity.
            # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
            if zero_init_residual:
                for m in self.modules():
                    if isinstance(m, Bottleneck):
                        nn.init.constant_(m.bn3.weight, 0)
                    elif isinstance(m, BasicBlock):
                        nn.init.constant_(m.bn2.weight, 0)

        def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
            norm_layer = self._norm_layer
            downsample = None
            previous_dilation = self.dilation
            if dilate:
                # Trade the stride for dilation so the resolution is preserved.
                self.dilation *= stride
                stride = 1
            if stride != 1 or self.inplanes != planes * block.expansion:
                # 1x1 projection so the shortcut matches the block output shape.
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, planes * block.expansion, stride),
                    norm_layer(planes * block.expansion),
                )

            layers = []
            # Only the first block of a stage changes stride or channel count.
            layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                                self.base_width, previous_dilation, norm_layer))
            self.inplanes = planes * block.expansion
            for _ in range(1, blocks):
                layers.append(block(self.inplanes, planes, groups=self.groups,
                                    base_width=self.base_width, dilation=self.dilation,
                                    norm_layer=norm_layer))

            return nn.Sequential(*layers)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)

            x = self.avgpool(x)
            # Flatten (N, C, 1, 1) to (N, C) before the classifier.
            x = x.reshape(x.size(0), -1)
            x = self.fc(x)

            return x


    def _resnet(arch, block, layers, pretrained, progress, **kwargs):
        """Build a ResNet from a block type and per-stage block counts."""
        model = ResNet(block, layers, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=progress)
            model.load_state_dict(state_dict)
        return model


    def resnet18(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-18 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
                       **kwargs)


    def resnet34(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-34 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                       **kwargs)


    def resnet50(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-50 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                       **kwargs)


    def resnet101(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-101 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                       **kwargs)


    def resnet152(pretrained=False, progress=True, **kwargs):
        """Constructs a ResNet-152 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
            progress (bool): If True, displays a progress bar of the download to stderr
        """
        return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                       **kwargs)


    def resnext50_32x4d(**kwargs):
        """Constructs a ResNeXt-50 32x4d model (model_urls lists no pretrained weights for it)."""
        kwargs['groups'] = 32
        kwargs['width_per_group'] = 4
        return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                       pretrained=False, progress=True, **kwargs)


    def resnext101_32x8d(**kwargs):
        """Constructs a ResNeXt-101 32x8d model (model_urls lists no pretrained weights for it)."""
        kwargs['groups'] = 32
        kwargs['width_per_group'] = 8
        return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                       pretrained=False, progress=True, **kwargs)
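
The per-stage block counts and the `width` expression in the listing above can be checked with a little arithmetic. The sketch below is a plain-Python aid that does not need PyTorch: `network_depth` and `bottleneck_width` are hypothetical helper names, and the depth count assumes the usual convention of counting the stem convolution, the convolutions inside the residual branches, and the final fully connected layer, while leaving out the 1x1 projection shortcuts.

```python
def network_depth(layers, convs_per_block):
    """Count weighted layers: stem conv + convs in all residual blocks + fc."""
    return 1 + convs_per_block * sum(layers) + 1


def bottleneck_width(planes, base_width=64, groups=1):
    """Same expression as `width` in Bottleneck.__init__."""
    return int(planes * (base_width / 64.)) * groups


# (blocks per stage, convolutions per block) taken from the constructors above:
# BasicBlock has 2 convolutions, Bottleneck has 3.
configs = {
    "resnet18": ([2, 2, 2, 2], 2),
    "resnet34": ([3, 4, 6, 3], 2),
    "resnet50": ([3, 4, 6, 3], 3),
    "resnet101": ([3, 4, 23, 3], 3),
    "resnet152": ([3, 8, 36, 3], 3),
}
depths = {name: network_depth(layers, convs) for name, (layers, convs) in configs.items()}

# Inner 3x3 width of the first-stage bottleneck (planes=64) for each variant.
widths = {
    "resnet50": bottleneck_width(64),
    "resnext50_32x4d": bottleneck_width(64, base_width=4, groups=32),
    "resnext101_32x8d": bottleneck_width(64, base_width=8, groups=32),
}

print(depths)
print(widths)
```

Under these assumptions the computed depths reproduce the numbers in the model names, and the ResNeXt variants come out with a wider grouped 3x3 convolution than the plain ResNet bottleneck for the same `planes`.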

    References

    https://github.com/pytorch/vision/tree/master/torchvision/models

  • Original post: https://www.cnblogs.com/ys99/p/10872262.html