zoukankan      html  css  js  c++  java
  • 残差网络ResNet(超详细代码解析) :你必须要知道backbone模块成员之一

     

             本文主要贡献代码模块(文末),在本文中对resnet进行了复现,是一份原始版本模块,里面集成了权重文件pth的载入模块(如函数:init_weights(self, pretrained=None)),layers的冻结模块(如函数:_freeze_stages(self)),更是将其改写成可读性高的代码,若你需要执行该模块,可直接将其代码模块粘贴成.py文件即可。而理论模块,并非本文重点,因此借鉴博客:https://zhuanlan.zhihu.com/p/42706477 ,我将不再说明:

    注:本人也意在改写更多backbones模块,后续将会放入该github中,可供代码下载:https://github.com/tangjunjun966/backbones

    本文机构:1.基本原理;2. Resnet代码复现;3.代码运行结果展示

     1.基本原理

    ResNet的作者何凯明也因此摘得CVPR2016最佳论文奖,当然何博士的成就远不止于此,感兴趣的可以去搜一下他后来的辉煌战绩。那么ResNet为什么会有如此优异的表现呢?其实ResNet是解决了深度CNN模型难训练的问题,从图2中可以看到14年的VGG才19层,而15年的ResNet多达152层,这在网络深度完全不是一个量级上,所以如果是第一眼看这个图的话,肯定会觉得ResNet是靠深度取胜。事实当然是这样,但是ResNet还有架构上的trick,这才使得网络的深度发挥出作用,这个trick就是残差学习(Residual learning)。下面详细讲述ResNet的理论及实现。

    深度网络的退化问题

    从经验来看,网络的深度对模型的性能至关重要,当增加网络层数后,网络可以进行更加复杂的特征模式的提取,所以当模型更深时理论上可以取得更好的结果,从图2中也可以看出网络越深而效果越好的一个实践证据。但是更深的网络其性能一定会更好吗?实验发现深度网络出现了退化问题(Degradation problem):网络深度增加时,网络准确度出现饱和,甚至出现下降。这个现象可以在图3中直观看出来:56层的网络比20层网络效果还要差。这不会是过拟合问题,因为56层网络的训练误差同样高。我们知道深层网络存在着梯度消失或者爆炸的问题,这使得深度学习模型很难训练。但是现在已经存在一些技术手段如BatchNorm来缓解这个问题。因此,出现深度网络的退化问题是非常令人诧异的。

    图3 20层与56层网络在CIFAR-10上的误差

    残差学习

    深度网络的退化问题至少说明深度网络不容易训练。但是我们考虑这样一个事实:现在你有一个浅层网络,你想通过向上堆积新层来建立深层网络,一个极端情况是这些增加的层什么也不学习,仅仅复制浅层网络的特征,即这样新层是恒等映射(Identity mapping)。在这种情况下,深层网络应该至少和浅层网络性能一样,也不应该出现退化现象。好吧,你不得不承认肯定是目前的训练方法有问题,才使得深层网络很难去找到一个好的参数。

    这个有趣的假设让何博士灵感爆发,他提出了残差学习来解决退化问题。对于一个堆积层结构(几层堆积而成)当输入为 [公式] 时其学习到的特征记为 [公式] ,现在我们希望其可以学习到残差 [公式] ,这样其实原始的学习特征是 [公式] 。之所以这样是因为残差学习相比原始特征直接学习更容易。当残差为0时,此时堆积层仅仅做了恒等映射,至少网络性能不会下降,实际上残差不会为0,这也会使得堆积层在输入特征基础上学习到新的特征,从而拥有更好的性能。残差学习的结构如图4所示。这有点类似与电路中的“短路”,所以是一种短路连接(shortcut connection)。

    图4 残差学习单元

    为什么残差学习相对更容易,从直观上看残差学习需要学习的内容少,因为残差一般会比较小,学习难度小点。不过我们可以从数学的角度来分析这个问题,首先残差单元可以表示为:

    [公式]

    其中 [公式] 和 [公式] 分别表示的是第 [公式] 个残差单元的输入和输出,注意每个残差单元一般包含多层结构。 [公式] 是残差函数,表示学习到的残差,而 [公式] 表示恒等映射, [公式] 是ReLU激活函数。基于上式,我们求得从浅层 [公式] 到深层 [公式] 的学习特征为:

    [公式]

    利用链式规则,可以求得反向过程的梯度:

    [公式]

    式子的第一个因子 [公式] 表示的损失函数到达 [公式] 的梯度,小括号中的1表明短路机制可以无损地传播梯度,而另外一项残差梯度则需要经过带有weights的层,梯度不是直接传递过来的。残差梯度不会那么巧全为-1,而且就算其比较小,有1的存在也不会导致梯度消失。所以残差学习会更容易。要注意上面的推导并不是严格的证明。

    ResNet的网络结构

    ResNet网络是参考了VGG19网络,在其基础上进行了修改,并通过短路机制加入了残差单元,如图5所示。变化主要体现在ResNet直接使用stride=2的卷积做下采样,并且用global average pool层替换了全连接层。ResNet的一个重要设计原则是:当feature map大小降低一半时,feature map的数量增加一倍,这保持了网络层的复杂度。从图5中可以看到,ResNet相比普通网络每两层间增加了短路机制,这就形成了残差学习,其中虚线表示feature map数量发生了改变。图5展示的34-layer的ResNet,还可以构建更深的网络如表1所示。从表中可以看到,对于18-layer和34-layer的ResNet,其进行的两层间的残差学习,当网络更深时,其进行的是三层间的残差学习,三层卷积核分别是1x1,3x3和1x1,一个值得注意的是隐含层的feature map数量是比较小的,并且是输出feature map数量的1/4。

    图5 ResNet网络结构图表1 不同深度的ResNet

    下面我们再分析一下残差单元,ResNet使用两种残差单元,如图6所示。左图对应的是浅层网络,而右图对应的是深层网络。对于短路连接,当输入和输出维度一致时,可以直接将输入加到输出上。但是当维度不一致时(对应的是维度增加一倍),这就不能直接相加。有两种策略:(1)采用zero-padding增加维度,此时一般要先做一个downsamp,可以采用strde=2的pooling,这样不会增加参数;(2)采用新的映射(projection shortcut),一般采用1x1的卷积,这样会增加参数,也会增加计算量。短路连接除了直接使用恒等映射,当然都可以采用projection shortcut。

    图6 不同的残差单元

    作者对比18-layer和34-layer的网络效果,如图7所示。可以看到普通的网络出现退化现象,但是ResNet很好的解决了退化问题。

    图7 18-layer和34-layer的网络效果

    最后展示一下ResNet网络与其他网络在ImageNet上的对比结果,如表2所示。可以看到ResNet-152其误差降到了4.49%,当采用集成模型后,误差可以降到3.57%。

    表2 ResNet与其他网络的对比结果

    说一点关于残差单元题外话,上面我们说到了短路连接的几种处理方式,其实作者在文献[2]中又对不同的残差单元做了细致的分析与实验,这里我们直接抛出最优的残差结构,如图8所示。改进前后一个明显的变化是采用pre-activation,BN和ReLU都提前了。而且作者推荐短路连接采用恒等变换,这样保证短路连接不会有阻碍。感兴趣的可以去读读这篇文章。

     2.Resnet代码复现

    """
    @author: tangjun
    @contact: 511026664@qq.com
    @time: 2020/12/7 22:48
    @desc: 残差ackbone
    """
    
    import torch.nn as nn
    import torch
    from collections import OrderedDict
    
    
    def Conv(in_planes, out_planes, **kwargs):
        "3x3 convolution with padding"
        padding = kwargs.get('padding', 1)
        bias = kwargs.get('bias', False)
        stride = kwargs.get('stride', 1)
        kernel_size = kwargs.get('kernel_size', 3)
        out = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        return out
    
    
    class BasicBlock(nn.Module):
        expansion = 1
    
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super(BasicBlock, self).__init__()
            self.conv1 = Conv(inplanes, planes, stride=stride)
            self.bn1 = nn.BatchNorm2d(planes)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = Conv(planes, planes)
            self.bn2 = nn.BatchNorm2d(planes)
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            residual = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
    
            if self.downsample is not None:
                residual = self.downsample(x)
    
            out += residual
            out = self.relu(out)
    
            return out
    
    
    class Bottleneck(nn.Module):
        expansion = 4
    
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super(Bottleneck, self).__init__()
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
            self.bn1 = nn.BatchNorm2d(planes)
            self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                   padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(planes)
            self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
            self.bn3 = nn.BatchNorm2d(planes * 4)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride
    
        def forward(self, x):
            residual = x
    
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
    
            out = self.conv2(out)
            out = self.bn2(out)
            out = self.relu(out)
    
            out = self.conv3(out)
            out = self.bn3(out)
    
            if self.downsample is not None:
                residual = self.downsample(x)
    
            out += residual
            out = self.relu(out)
    
            return out
    
    
    class Resnet(nn.Module):
        arch_settings = {
            18: (BasicBlock, (2, 2, 2, 2)),
            34: (BasicBlock, (3, 4, 6, 3)),
            50: (Bottleneck, (3, 4, 6, 3)),
            101: (Bottleneck, (3, 4, 23, 3)),
            152: (Bottleneck, (3, 8, 36, 3))
        }
    
        def __init__(self, depth,
                     in_channels=None,
                     pretrained=None,
    
                    frozen_stages=-1
    
                     # num_classes=None
                     ):
            self.inplanes = 64
            super(Resnet, self).__init__()
    
            self.inchannels = in_channels if in_channels is not None else 3  # 输入通道
            # self.num_classes=num_classes
            self.block, layers = self.arch_settings[depth]
            self.frozen_stages=frozen_stages
            self.conv1 = nn.Conv2d(self.inchannels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    
            self.layer1 = self._make_layer(self.block, 64, layers[0], stride=1)
            self.layer2 = self._make_layer(self.block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(self.block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(self.block, 512, layers[3], stride=2)
    
            # self.avgpool = nn.AvgPool2d(7)
            # self.fc = nn.Linear(512 * self.block.expansion, self.num_classes)
            self._freeze_stages()  # 冻结函数
        def _freeze_stages(self):
            if self.frozen_stages >= 0:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False
    
            for i in range(1, self.frozen_stages + 1):
                m = getattr(self, 'layer{}'.format(i))
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False
    
    
        def init_weights(self, pretrained=None):
            if isinstance(pretrained, str):
                self.load_checkpoint(pretrained)
            elif pretrained is None:
                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='relu')
                        if hasattr(m, 'bias') and m.bias is not None: # m包含该属性且m.bias非None # hasattr(对象,属性)表示对象是否包含该属性
                            nn.init.constant_(m.bias, 0)
    
                    elif isinstance(m, nn.BatchNorm2d):
                        m.weight.data.fill_(1)
                        m.bias.data.zero_()
    
        def load_checkpoint(self, pretrained):
    
            checkpoint = torch.load(pretrained)
            if isinstance(checkpoint, OrderedDict):
                state_dict = checkpoint
            elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                state_dict = checkpoint['state_dict']
    
            if list(state_dict.keys())[0].startswith('module.'):
                state_dict = {k[7:]: v for k, v in checkpoint['state_dict'].items()}
    
            unexpected_keys = []  # 保存checkpoint不在module中的key
            model_state = self.state_dict()  # 模型变量
    
            for name, param in state_dict.items():  # 循环遍历pretrained的权重
                if name not in model_state:
                    unexpected_keys.append(name)
                    continue
                if isinstance(param, torch.nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
    
                try:
                    model_state[name].copy_(param)  # 试图赋值给模型
                except Exception:
                    raise RuntimeError(
                        'While copying the parameter named {}, '
                        'whose dimensions in the model are {} not equal '
                        'whose dimensions in the checkpoint are {}.'.format(
                            name, model_state[name].size(), param.size()))
            missing_keys = set(model_state.keys()) - set(state_dict.keys())
            print('missing_keys:',missing_keys)
        def _make_layer(self, block, planes, num_blocks, stride=1):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion,
                              kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )
    
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes * block.expansion
            for i in range(1, num_blocks):
                layers.append(block(self.inplanes, planes))
    
            return nn.Sequential(*layers)
    
        def forward(self, x):
            outs = []
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
    
            x = self.layer1(x)
            outs.append(x)
            x = self.layer2(x)
            outs.append(x)
            x = self.layer3(x)
            outs.append(x)
            x = self.layer4(x)
            outs.append(x)
    
            # x = self.avgpool(x)
            # x = x.view(x.size(0), -1)
            # x = self.fc(x)
    
            return tuple(outs)
    
    
    if __name__ == '__main__':
        x = torch.ones((2, 3, 215, 215))
        model = Resnet(depth=50)
    
        model.init_weights(pretrained='./resnet50.pth')
    
    
        out = model(x)
    
        print(out)

     3.代码运行结果展示

    处理算法通用的辅助的code,如读取txt文件,读取xml文件,将xml文件转换成txt文件,读取json文件等
  • 相关阅读:
    Teamviewer12完美破解版去除时间限制完美使用
    winFrom程序更新自动安装
    sql 根据指定字符截取前面几个字符
    下拉框带搜索
    easyui 展开缩起
    p1001 谁拿了最多的奖学金
    p1217 乒乓球
    p1911 珠心算问题
    p1848 记数问题
    入坑archlinux
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/14947868.html
Copyright © 2011-2022 走看看