  • PyTorch ResNet implementation: a code walkthrough

    Reposted from https://blog.csdn.net/qq_31347869/article/details/100566719
    torchvision is a very handy package in the PyTorch ecosystem. It consists mainly of three subpackages: torchvision.datasets, torchvision.models and torchvision.transforms.
    Official documentation: http://pytorch.org/docs/master/torchvision/index.html
    Source code: https://github.com/pytorch/vision/tree/master/torchvision
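    I. Pretrained models
    torchvision implements several models, including common architectures such as AlexNet, DenseNet, ResNet and VGG. It also provides weights for them that were pretrained on ImageNet.
    Loading a model: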

    import torchvision
    model = torchvision.models.resnet50(pretrained=True)
    

    To initialize the model without pretrained parameters (random initialization):

    model = torchvision.models.resnet50(pretrained=False)
    # pretrained defaults to False, so this is equivalent to
    model = torchvision.models.resnet50()
    

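    II. Importing ResNet models of different depths
    The __all__ variable lists the function and class names that can be imported from outside the module (for example with "from torchvision.models.resnet import *"). The pretrained weights for each network are loaded from the addresses in model_urls.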

    import torch
    import torch.nn as nn
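    # Inside the torchvision package this is the relative import "from .utils import load_state_dict_from_url";
    # outside the package the same function is available from torch.hub.
    from torch.hub import load_state_dict_from_url
    
    # ResNet models of different depths (and the ResNeXt / Wide ResNet variants)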
    __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
               'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
               'wide_resnet50_2', 'wide_resnet101_2']
               
    model_urls = {
        'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
        'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
        'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
        'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
        'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
    }
    

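    III. Basic structure of the model
    1. The 3x3 convolution module
    By default stride=1 and dilation=1, and the padding is set equal to dilation. As a result, with stride 1 the output keeps the input's height and width. in_planes and out_planes are the numbers of input and output channels. groups is the grouped-convolution parameter; groups=1 means no grouping, i.e. an ordinary convolution. bias=False because every convolution is followed by a BatchNorm layer, which has its own shift term.
    For grouped convolution, see: "Ten amazing operations in convolutional neural networks" (卷积神经网络中十大拍案叫绝的操作)
    "What the groups parameter does in PyTorch" (PyTorch中groups函数的作用)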

    def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
        """3x3 convolution with padding"""
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=dilation, groups=groups, bias=False, dilation=dilation)
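To make the parameters concrete, here is a small sketch that reuses the conv3x3 definition above on a dummy 64-channel, 56x56 feature map. The input size is an arbitrary assumption. The sketch checks three things:

- the output shape for different stride and dilation values;
- that padding=dilation keeps the spatial size when stride is 1;
- how groups shrinks the weight tensor, whose shape is (out, in/groups, 3, 3).

```python
import torch
import torch.nn as nn

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding (same definition as above)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)

x = torch.randn(1, 64, 56, 56)  # (batch, channels, height, width)

same = conv3x3(64, 64)(x)              # stride 1: height/width unchanged
half = conv3x3(64, 128, stride=2)(x)   # stride 2: height/width halved
dil = conv3x3(64, 64, dilation=2)(x)   # padding=dilation=2 still keeps 56x56
print(same.shape, half.shape, dil.shape)
# torch.Size([1, 64, 56, 56]) torch.Size([1, 128, 28, 28]) torch.Size([1, 64, 56, 56])

# groups splits the channels into independent groups, dividing the weight count by groups.
dense = conv3x3(64, 64).weight.numel()               # 64 * 64 * 3 * 3
grouped = conv3x3(64, 64, groups=32).weight.numel()  # 64 * (64 // 32) * 3 * 3
print(dense, grouped)  # 36864 1152
```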
    

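    2. The 1x1 convolution module
    By default stride=1. A 1x1 convolution mixes the channels at each pixel without looking at neighboring pixels, so it is used to change the number of channels. With stride > 1 it also downsamples the feature map.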

    def conv1x1(in_planes, out_planes, stride=1):
        """1x1 convolution"""
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
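A quick sketch of the two typical uses of conv1x1, on a dummy input whose size is an assumption:

- changing only the channel count, as in the expand step of Bottleneck;
- changing channels and halving the spatial size with stride=2.

As far as I know, torchvision's ResNet builds its downsample shortcut from the second form followed by BatchNorm.

```python
import torch
import torch.nn as nn

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution (same definition as above)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

x = torch.randn(1, 64, 56, 56)

expand = conv1x1(64, 256)          # channels only: 64 -> 256
proj = conv1x1(64, 128, stride=2)  # channels 64 -> 128 and 56x56 -> 28x28

y_expand = expand(x)
y_proj = proj(x)
print(y_expand.shape)         # torch.Size([1, 256, 56, 56])
print(y_proj.shape)           # torch.Size([1, 128, 28, 28])
print(expand.weight.numel())  # 16384 = 64 * 256 * 1 * 1
```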
    

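    3. The BasicBlock module
    BasicBlock defines the most important component of ResNet, the residual block. It uses two 3x3 convolutions in the order conv → BN → ReLU → conv → BN. The shortcut is then added, and a final ReLU is applied.
    [1] super(BasicBlock, self).__init__() is standard boilerplate.
    Neural network classes normally inherit from torch.nn.Module, and __init__() and forward() are the two main methods of such a class. In a custom class's __init__() you call super(Net, self).__init__(), where Net is the name of your class, so that the parent class's initializer runs. In Python 3, super().__init__() is equivalent.
    Note that __init__() only declares the network's submodules; the actual computation is wired up in forward(). Members of the class are accessed through self, which is why every method takes self as its first parameter.
    [2] out += identity is the essence of ResNet: the block's input x (the identity shortcut) is added to its output, so the stacked layers only need to learn the residual.
    [3] if self.downsample is not None handles the case where the block changes the spatial size (stride != 1) or the number of channels. In that case x no longer has the same shape as out, so downsample, which is supplied by the caller, projects x to the matching shape before the addition.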

    class BasicBlock(nn.Module):
        expansion = 1
        __constants__ = ['downsample']
    
        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None):
            super(BasicBlock, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            if groups != 1 or base_width != 64:
                raise ValueError('BasicBlock only supports groups=1 and base_width=64')
            if dilation > 1:
                raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
            # Both self.conv1 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = norm_layer(planes)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = conv3x3(planes, planes)
            self.bn2 = norm_layer(planes)
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            identity = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
    
            if self.downsample is not None:
                identity = self.downsample(x)
    
            out += identity
            out = self.relu(out)
    
            return out
    

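    4. The Bottleneck module
    Unlike BasicBlock, Bottleneck has three convolutions with 1x1, 3x3 and 1x1 kernels. They are used, respectively, to reduce the number of channels, to perform the spatial convolution, and to restore the number of channels.
    inplanes is the number of input channels, and planes is the number of output channels divided by expansion. expansion is the multiplier applied to the output channel count. In the basic BasicBlock, expansion is 1, so there is no multiplication and the output channel count equals planes.
    Note: Bottleneck first compresses the channels and then expands them again. The planes argument is therefore not the actual number of output channels but the compressed channel count inside the block (when groups=1 and base_width=64). The real number of output channels is planes * expansion.
    The main purpose of this design is that the Bottleneck structure reduces the number of network parameters.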

    class Bottleneck(nn.Module):
        expansion = 4
    
        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                     base_width=64, dilation=1, norm_layer=None):
            super(Bottleneck, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            width = int(planes * (base_width / 64.)) * groups
            # Both self.conv2 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv1x1(inplanes, width)
            self.bn1 = norm_layer(width)
            self.conv2 = conv3x3(width, width, stride, groups, dilation)
            self.bn2 = norm_layer(width)
            self.conv3 = conv1x1(width, planes * self.expansion)
            self.bn3 = norm_layer(planes * self.expansion)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            identity = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
    
            out = self.conv3(out)
            out = self.bn3(out)
    
            if self.downsample is not None:
                identity = self.downsample(x)
    
            out += identity
            out = self.relu(out)
    
            return out
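The line width = int(planes * (base_width / 64.)) * groups is what lets the same Bottleneck class serve ResNet, ResNeXt and Wide ResNet.

The sketch below mirrors that formula in plain Python for the first stage (planes=64). The helper name bottleneck_channels is hypothetical. The groups/base_width values are the ones I believe torchvision passes for each model name listed in __all__: 32 groups of 4 or 8 channels for ResNeXt, and a base width of 128 for wide_resnet50_2.

```python
def bottleneck_channels(planes, groups=1, base_width=64, expansion=4):
    """Hypothetical helper mirroring Bottleneck.__init__ above.

    Returns (channels of the inner 3x3 conv, output channels of the block).
    """
    width = int(planes * (base_width / 64.)) * groups
    return width, planes * expansion

# First stage (planes=64) of several variants: (name, groups, base_width).
variants = [
    ('resnet50', 1, 64),
    ('resnext50_32x4d', 32, 4),
    ('resnext101_32x8d', 32, 8),
    ('wide_resnet50_2', 1, 128),
]
for name, groups, base_width in variants:
    print(name, bottleneck_channels(64, groups, base_width))
```

Only the inner 3x3 width changes between variants; the block's output channels stay planes * expansion = 256 in every case.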
    

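    As in the figure below, the left structure has 256 × 3 × 3 × 256 = 589,824 parameters. The right one has 256 × 1 × 1 × 64 + 64 × 3 × 3 × 64 + 64 × 1 × 1 × 256 = 69,632, far fewer than the left.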
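The arithmetic above can be reproduced with a few lines of Python. This sketch counts only the convolution weights (bias=False, as in the code). It ignores the small number of BatchNorm parameters, and conv_params is a hypothetical helper name.

```python
def conv_params(in_ch, out_ch, k):
    """Weights of a bias-free k x k convolution: in_ch * k * k * out_ch."""
    return in_ch * k * k * out_ch

# Left: a single 3x3 convolution, 256 -> 256 channels.
plain = conv_params(256, 256, 3)

# Right: bottleneck 1x1 (256 -> 64), 3x3 (64 -> 64), 1x1 (64 -> 256).
bottleneck = (conv_params(256, 64, 1)
              + conv_params(64, 64, 3)
              + conv_params(64, 256, 1))

print(plain, bottleneck)             # 589824 69632
print(round(plain / bottleneck, 1))  # 8.5
```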

    5. The main body of the residual network

  • Original post: https://www.cnblogs.com/zyr001/p/14550397.html