zoukankan      html  css  js  c++  java
  • 【深度学习】基于Pytorch的ResNet实现

    1. ResNet理论

    论文:https://arxiv.org/pdf/1512.03385.pdf

    残差学习基本单元:

    img

    在ImageNet上的结果:

    效果会随着模型层数的提升而下降,当更深的网络能够开始收敛时,就会出现降级问题:随着网络深度的增加,准确度变得饱和(这可能不足为奇),然后迅速降级。

    ResNet模型:

    2. pytorch实现

    2.1 基础卷积

    conv3$ imes(3 和conv1) imes$1 基础模块

    def conv3x3(in_channel, out_channel, stride=1, groups=1, dilation=1):
        return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
    
    def conv1x1(in_channel, out_channel, stride=1):
        return nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)
    

    参数解释:

    in_channel: 输入的通道数目

    out_channel:输出的通道数目

    stride, padding: 步长和补0

    dilation: 空洞卷积中的参数

    groups: 从输入通道到输出通道的阻塞连接数

    feature size 计算:
    output = (intput - filter_size + 2 x padding) / stride + 1
    

    空洞卷积实际卷积核大小:

    K = K + (K-1)x(R-1)
    K 是原始卷积核大小
    R 是空洞卷积参数的空洞率(普通卷积为1)
    

    2.2 模块

    - resnet34
    	- _resnet
    		- ResNet
    			- _make_layer
    				- block 
    					- Bottleneck
    					- BasicBlock			
    

    Bottlenect

    class Bottleneck(nn.Module):
        expansion = 4
        __constants__ = ['downsample']
    
        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                     base_width=64, dilation=1, norm_layer=None):
            super(Bottleneck, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            width = int(planes * (base_width / 64.)) * groups
            # Both self.conv2 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv1x1(inplanes, width)
            self.bn1 = norm_layer(width)
            self.conv2 = conv3x3(width, width, stride, groups, dilation)
            self.bn2 = norm_layer(width)
            self.conv3 = conv1x1(width, planes * self.expansion)
            self.bn3 = norm_layer(planes * self.expansion)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            identity = x
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
            out = self.conv3(out)
            out = self.bn3(out)
            if self.downsample is not None:
                identity = self.downsample(x)
            out += identity
            out = self.relu(out)
            return out
    

    BasicBlock

    class BasicBlock(nn.Module):
        expansion = 1
        __constants__ = ['downsample']
    
        def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                     base_width=64, dilation=1, norm_layer=None):
            super(BasicBlock, self).__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            if groups != 1 or base_width != 64:
                raise ValueError('BasicBlock only supports groups=1 and base_width=64')
            if dilation > 1:
                raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
            # Both self.conv1 and self.downsample layers downsample the input when stride != 1
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = norm_layer(planes)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = conv3x3(planes, planes)
            self.bn2 = norm_layer(planes)
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            identity = x
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            if self.downsample is not None:
                identity = self.downsample(x)
            out += identity
            out = self.relu(out)
            return out
    

    2.3 使用ResNet模块进行迁移学习

    import torchvision.models as models
    import torch.nn as nn
    
    class RES18(nn.Module):
        def __init__(self):
            super(RES18, self).__init__()
            self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
            self.base = torchvision.models.resnet18(pretrained=False)
            self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
        def forward(self, x):
            out = self.base(x)
            return out
    
    class RES34(nn.Module):
        def __init__(self):
            super(RES34, self).__init__()
            self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
            self.base = torchvision.models.resnet34(pretrained=False)
            self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
        def forward(self, x):
            out = self.base(x)
            return out
    
    class RES50(nn.Module):
        def __init__(self):
            super(RES50, self).__init__()
            self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
            self.base = torchvision.models.resnet50(pretrained=False)
            self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
        def forward(self, x):
            out = self.base(x)
            return out
    
    class RES101(nn.Module):
        def __init__(self):
            super(RES101, self).__init__()
            self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
            self.base = torchvision.models.resnet101(pretrained=False)
            self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
        def forward(self, x):
            out = self.base(x)
            return out
    
    class RES152(nn.Module):
        def __init__(self):
            super(RES152, self).__init__()
            self.num_cls = settings.MAX_CAPTCHA*settings.ALL_CHAR_SET_LEN
            self.base = torchvision.models.resnet152(pretrained=False)
            self.base.fc = nn.Linear(self.base.fc.in_features, self.num_cls)
        def forward(self, x):
            out = self.base(x)
            return out
    

    使用模块直接生成一个类即可,比如训练的时候:

    cnn = RES101()
    cnn.train() # 改为训练模式
    prediction = cnn(img) #进行预测
    

    目前先写这么多,看过了源码以后感觉写的很好,不仅仅有论文中最基础的部分,还有一些额外的功能,模块的组织也很整齐。

    平时使用一般都进行迁移学习,使用的话可以把上述几个类中pretrained=False参数改为True.

    实战篇:以上迁移学习代码来自我的一个小项目,验证码识别,地址:https://github.com/pprp/captcha_identify.torch

  • 相关阅读:
    【摄影】延时摄影
    【sas sql proc】统计
    【分析模板】excel or sas
    JavaScript的方法和技巧
    好书推荐
    七招制胜ASP.NET应用程序开发
    .Net中使用带返回值的存储过程(VB代码)
    ASP.NET 2.0构建动态导航的Web应用程序(TreeView和Menu )
    简单查询和联合查询两方面介绍SQL查询语句
    数字金额大小写转换之存储过程
  • 原文地址:https://www.cnblogs.com/pprp/p/11721587.html
Copyright © 2011-2022 走看看