zoukankan      html  css  js  c++  java
  • VGG网络的PyTorch实现

    1.文章原文地址

    Very Deep Convolutional Networks for Large-Scale Image Recognition

    2.文章摘要

    在这项工作中,我们研究了在大规模图像识别任务中,卷积神经网络的深度对其准确率的影响。我们的主要贡献是使用非常小的(3×3)卷积核架构,对深度不断增加的网络进行了全面评估,结果表明将深度增加到16-19个权重层时,性能较已有配置有显著提升。这些发现是我们参加ImageNet Challenge 2014的基础,我们的团队在定位和分类任务中分别获得了第一名和第二名。我们还表明,该网络学到的表征可以很好地推广到其他数据集,并在这些数据集上取得了当前最好的结果。为了促进在计算机视觉中使用深度视觉表征的进一步研究,我们公开了两个性能最佳的ConvNet模型。

    3.网络结构

    4.PyTorch实现

      1 import torch.nn as nn
      2 try:
      3     from torch.hub import load_state_dict_from_url
      4 except ImportError:
      5     from torch.utils.model_zoo import load_url as load_state_dict_from_url
      6 
      7 __all__ = [
      8     'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
      9     'vgg19_bn', 'vgg19',
     10 ]
     11 
     12 
     13 model_urls = {
     14     'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
     15     'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
     16     'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
     17     'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
     18     'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
     19     'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
     20     'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
     21     'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
     22 }
     23 
     24 
     25 class VGG(nn.Module):
     26 
     27     def __init__(self, features, num_classes=1000, init_weights=True):
     28         super(VGG, self).__init__()
     29         self.features = features
     30         self.avgpool = nn.AdaptiveAvgPool2d((7, 7))  #固定全连接层的输入
     31         self.classifier = nn.Sequential(
     32             nn.Linear(512 * 7 * 7, 4096),
     33             nn.ReLU(True),
     34             nn.Dropout(),
     35             nn.Linear(4096, 4096),
     36             nn.ReLU(True),
     37             nn.Dropout(),
     38             nn.Linear(4096, num_classes),
     39         )
     40         if init_weights:
     41             self._initialize_weights()
     42 
     43     def forward(self, x):
     44         x = self.features(x)
     45         x = self.avgpool(x)
     46         x = x.view(x.size(0), -1)
     47         x = self.classifier(x)
     48         return x
     49 
     50     def _initialize_weights(self):
     51         for m in self.modules():
     52             if isinstance(m, nn.Conv2d):
     53                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
     54                 if m.bias is not None:
     55                     nn.init.constant_(m.bias, 0)
     56             elif isinstance(m, nn.BatchNorm2d):
     57                 nn.init.constant_(m.weight, 1)
     58                 nn.init.constant_(m.bias, 0)
     59             elif isinstance(m, nn.Linear):
     60                 nn.init.normal_(m.weight, 0, 0.01)
     61                 nn.init.constant_(m.bias, 0)
     62 
     63 
     64 def make_layers(cfg, batch_norm=False):
     65     layers = []
     66     in_channels = 3
     67     for v in cfg:
     68         if v == 'M':
     69             layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
     70         else:
     71             conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
     72             if batch_norm:
     73                 layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
     74             else:
     75                 layers += [conv2d, nn.ReLU(inplace=True)]
     76             in_channels = v
     77     return nn.Sequential(*layers)
     78 
     79 
     80 cfgs = {
     81     'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
     82     'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
     83     'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
     84     'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
     85 }
     86 
     87 
     88 def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
     89     if pretrained:
     90         kwargs['init_weights'] = False
     91     model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
     92     if pretrained:
     93         state_dict = load_state_dict_from_url(model_urls[arch],
     94                                               progress=progress)
     95         model.load_state_dict(state_dict)
     96     return model
     97 
     98 
     99 def vgg11(pretrained=False, progress=True, **kwargs):
    100     """VGG 11-layer model (configuration "A")
    101     Args:
    102         pretrained (bool): If True, returns a model pre-trained on ImageNet
    103         progress (bool): If True, displays a progress bar of the download to stderr
    104     """
    105     return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs)
    106 
    107 
    108 def vgg11_bn(pretrained=False, progress=True, **kwargs):
    109     """VGG 11-layer model (configuration "A") with batch normalization
    110     Args:
    111         pretrained (bool): If True, returns a model pre-trained on ImageNet
    112         progress (bool): If True, displays a progress bar of the download to stderr
    113     """
    114     return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs)
    115 
    116 
    117 def vgg13(pretrained=False, progress=True, **kwargs):
    118     """VGG 13-layer model (configuration "B")
    119     Args:
    120         pretrained (bool): If True, returns a model pre-trained on ImageNet
    121         progress (bool): If True, displays a progress bar of the download to stderr
    122     """
    123     return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs)
    124 
    125 
    126 def vgg13_bn(pretrained=False, progress=True, **kwargs):
    127     """VGG 13-layer model (configuration "B") with batch normalization
    128     Args:
    129         pretrained (bool): If True, returns a model pre-trained on ImageNet
    130         progress (bool): If True, displays a progress bar of the download to stderr
    131     """
    132     return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs)
    133 
    134 
    135 def vgg16(pretrained=False, progress=True, **kwargs):
    136     """VGG 16-layer model (configuration "D")
    137     Args:
    138         pretrained (bool): If True, returns a model pre-trained on ImageNet
    139         progress (bool): If True, displays a progress bar of the download to stderr
    140     """
    141     return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs)
    142 
    143 
    144 def vgg16_bn(pretrained=False, progress=True, **kwargs):
    145     """VGG 16-layer model (configuration "D") with batch normalization
    146     Args:
    147         pretrained (bool): If True, returns a model pre-trained on ImageNet
    148         progress (bool): If True, displays a progress bar of the download to stderr
    149     """
    150     return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs)
    151 
    152 
    153 def vgg19(pretrained=False, progress=True, **kwargs):
    154     """VGG 19-layer model (configuration "E")
    155     Args:
    156         pretrained (bool): If True, returns a model pre-trained on ImageNet
    157         progress (bool): If True, displays a progress bar of the download to stderr
    158     """
    159     return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs)
    160 
    161 
    162 def vgg19_bn(pretrained=False, progress=True, **kwargs):
    163     """VGG 19-layer model (configuration 'E') with batch normalization
    164     Args:
    165         pretrained (bool): If True, returns a model pre-trained on ImageNet
    166         progress (bool): If True, displays a progress bar of the download to stderr
    167     """
    168     return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs)

     参考

    https://github.com/pytorch/vision/tree/master/torchvision/models

  • 相关阅读:
    C# 发送匿名邮件
    上传大文件,Web.config中的配置
    老话题关于文章自动分页
    scrollLeft,scrollWidth,clientWidth,offsetWidth到底指的哪到哪的距离
    转成静态页面,由于ie网址或路径原因,Atlas失效。
    让图片自适应大小的方法
    textoverflow 全兼容
    ISAPI_rewrite中文手册
    ISAPI_Rewrite集
    无限级下拉列表框控件
  • 原文地址:https://www.cnblogs.com/ys99/p/10835805.html
Copyright © 2011-2022 走看看