  • DenseNet Notes

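    1. Advantages of DenseNet

    • Alleviates the vanishing-gradient problem
    • Strengthens feature propagation
    • Makes full use of (reuses) features
    • Reduces the number of parameters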

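    2. Network Formulas

    For each layer l in a DenseBlock:

    $x_l = H_l([x_0, x_1, \ldots, x_{l-1}])$

    Here $[x_0, x_1, \ldots, x_{l-1}]$ denotes the concatenation of the feature maps output by layers 0 through l-1. Concatenation merges feature maps along the channel dimension, as in Inception. ResNet, by contrast, adds the values element-wise, so the number of channels does not change. $H_l$ consists of BN, ReLU and a 3×3 convolution.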

    In contrast, each residual block in ResNet computes:

    $x_l = H_l(x_{l-1}) + x_{l-1}$
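    A minimal PyTorch sketch contrasts the two connectivity patterns. Dense layers concatenate all earlier outputs, so the channel count grows by k per layer. Residual blocks add, so the channel count stays fixed. The tensor sizes, the three-layer depth and the helper make_h are hypothetical choices for illustration. H_l here is the plain BN-ReLU-Conv3×3 form without a bottleneck.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes: batch 2, 16 input channels, 8x8 feature maps, growth rate k = 12
x0 = torch.randn(2, 16, 8, 8)
k = 12

def make_h(in_ch, out_ch):
    # Hypothetical helper: H_l = BN -> ReLU -> 3x3 conv (padding=1 keeps the spatial size)
    return nn.Sequential(
        nn.BatchNorm2d(in_ch),
        nn.ReLU(),
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
    )

# Dense connectivity: x_l = H_l([x_0, ..., x_{l-1}])
features = [x0]
for l in range(1, 4):
    in_ch = sum(f.shape[1] for f in features)        # channels of the concatenated input
    h = make_h(in_ch, k)
    features.append(h(torch.cat(features, dim=1)))   # concatenate along the channel dim
dense_out = torch.cat(features, dim=1)

# Residual connectivity: x_l = H_l(x_{l-1}) + x_{l-1} (H_l must keep the channel count)
res = x0
for l in range(1, 4):
    res = make_h(16, 16)(res) + res

print(dense_out.shape)  # torch.Size([2, 52, 8, 8]): 16 + 3 * 12 channels
print(res.shape)        # torch.Size([2, 16, 8, 8]): channel count unchanged
```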

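    3. Growth Rate

    The growth rate k is the number of output channels of each nonlinear transformation H_l (BN, ReLU and a 3×3 convolution) in a DenseBlock. Each layer's output is concatenated with that layer's input. The number of channels a DenseBlock outputs is therefore its input channels plus (number of H_l) × growth_rate. Within a DenseBlock, the spatial size of the feature maps stays the same.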
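    The channel bookkeeping above reduces to two one-line formulas. This is a sketch with hypothetical example numbers: 24 input channels, 16 layers, k = 12. The function names are illustrative only.

```python
def dense_block_out_channels(in_channels, num_layers, growth_rate):
    """Channels leaving a DenseBlock: the block input plus k new maps per H_l."""
    return in_channels + num_layers * growth_rate

def layer_in_channels(in_channels, layer_index, growth_rate):
    """Channels entering the l-th H_l (1-based): block input plus the l-1 earlier outputs."""
    return in_channels + (layer_index - 1) * growth_rate

print(dense_block_out_channels(24, 16, 12))  # 24 + 16*12 = 216
print(layer_in_channels(24, 16, 12))         # 24 + 15*12 = 204
```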

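    4. Bottleneck

    This component sits inside a DenseBlock. Suppose a DenseBlock contains many nonlinear transformations H_l (e.g. 48) and the growth rate is k = 32. The input to the 48th layer then has input_channels + 47×32 channels, which is a very large number. Without a bottleneck to reduce the dimensionality first, the computational cost would be high.

    Therefore, a 1×1 convolution with 4×k output channels is applied first to reduce the dimensionality, so that the 3×3 convolution receives only 4×k input channels. The 1×1 convolution also fuses information across channels.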
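    To make the savings concrete, here is a rough weight count for the 48th layer in the example above. It compares a plain BN-ReLU-Conv3×3 H_l with the bottleneck version (1×1 conv to 4k channels, then 3×3 conv to k). The 64 channels entering the block are a hypothetical choice, and conv_weights is an illustrative helper. Bias terms are ignored, as in the code in section 7.

```python
def conv_weights(in_ch, out_ch, kernel):
    # Weight count of a bias-free k x k convolution
    # (equal to its multiply-accumulates per output pixel)
    return in_ch * out_ch * kernel * kernel

k = 32                      # growth rate from the example above
c0 = 64                     # hypothetical number of channels entering the DenseBlock
l = 48
in_ch = c0 + (l - 1) * k    # 64 + 47*32 = 1568

plain = conv_weights(in_ch, k, 3)                                        # 3x3 conv only
bottleneck = conv_weights(in_ch, 4 * k, 1) + conv_weights(4 * k, k, 3)   # 1x1 to 4k, then 3x3

print(in_ch, plain, bottleneck)  # 1568 451584 237568
```

    With these numbers, the bottleneck roughly halves the weights of this layer. The gap widens as the layer's input channel count grows, because only the 1×1 convolution depends on it.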

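    5. Transition

    This component sits between DenseBlocks. It uses a 1×1 convolution to reduce the number of channels to input_channels × reduction, where reduction defaults to 0.5. A pooling layer follows, downsampling the input to reduce the feature-map resolution.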
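    A shape check of a Transition layer, written as a stand-alone nn.Sequential. It mirrors the Transition class in section 7: BN, ReLU, 1×1 convolution, then 2×2 average pooling. The 216 input channels and the 32×32 input size are hypothetical example values.

```python
import math
import torch
import torch.nn as nn

in_ch, reduction = 216, 0.5                  # hypothetical block output, default reduction
out_ch = int(math.floor(in_ch * reduction))  # 108 channels after compression

transition = nn.Sequential(
    nn.BatchNorm2d(in_ch),
    nn.ReLU(),
    nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),  # 1x1 conv: channel reduction
    nn.AvgPool2d(kernel_size=2, stride=2),                # halves height and width
)

x = torch.randn(2, in_ch, 32, 32)
y = transition(x)
print(y.shape)  # torch.Size([2, 108, 16, 16])
```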

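    6. Network Architecture

    The network implemented in section 7 (the CIFAR variant) has the following layout. It starts with an initial 3×3 convolution, followed by three DenseBlocks separated by two Transition layers. It ends with BN-ReLU, 8×8 average pooling (global pooling for 32×32 inputs) and a fully connected classifier.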

    7. Code Implementation (PyTorch)

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import math
    
    class Bottleneck(nn.Module):
        def __init__(self,nChannels,growthRate):
            super(Bottleneck,self).__init__()
            interChannels = 4*growthRate
            self.bn1 = nn.BatchNorm2d(nChannels)
            self.conv1 = nn.Conv2d(nChannels,interChannels,kernel_size=1,
                                   stride=1,bias=False)
            self.bn2 = nn.BatchNorm2d(interChannels)
            self.conv2 = nn.Conv2d(interChannels,growthRate,kernel_size=3,
                                   stride=1,padding=1,bias=False)
    
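        def forward(self, x):
            # BN first (PyTorch's BatchNorm2d already includes the learnable scale and shift,
            # so no separate Scale layer is needed), then ReLU; conv1 acts as the bottleneck
            out = self.conv1(F.relu(self.bn1(x)))
            out = self.conv2(F.relu(self.bn2(out)))
            # Concatenate the input with the growthRate new feature maps along the channel dim
            out = torch.cat((x,out),1)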
            return out
    
    
    class SingleLayer(nn.Module):
        def __init__(self,nChannels,growthRate):
            super(SingleLayer,self).__init__()
            self.bn1 = nn.BatchNorm2d(nChannels)
            self.conv1 = nn.Conv2d(nChannels,growthRate,kernel_size=3,
                                   padding=1,bias=False)
    
        def forward(self, x):
            out = self.conv1(F.relu(self.bn1(x)))
            out = torch.cat((x,out),1)   # concatenate along the channel dimension
            return out
    
    class Transition(nn.Module):
        def __init__(self,nChannels,nOutChannels):
            super(Transition,self).__init__()
    
            self.bn1 = nn.BatchNorm2d(nChannels)
            self.conv1 = nn.Conv2d(nChannels,nOutChannels,kernel_size=1,bias=False)
    
        def forward(self, x):
            out = self.conv1(F.relu(self.bn1(x)))   # 1x1 conv reduces the channel count
            out = F.avg_pool2d(out,2)
            return out
    
    class DenseNet(nn.Module):
        def __init__(self,growthRate,depth,reduction,nClasses,bottleneck):
            super(DenseNet,self).__init__()
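            # Number of nonlinear transformations H_l per DenseBlock: depth minus 4 layers
            # (first conv, two transitions, fc), split evenly over the three blocks;
            # a bottleneck H_l contains two convolutions, so the count is halved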
            nNoneLinears = (depth-4)//3
            if bottleneck:
                nNoneLinears //=2
    
            nChannels = 2*growthRate
            self.conv1 = nn.Conv2d(3,nChannels,kernel_size=3,padding=1,bias=False)
            self.denseblock1 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
            nChannels += nNoneLinears*growthRate
            nOutChannels = int(math.floor(nChannels*reduction))        # round down
            self.transition1 = Transition(nChannels,nOutChannels)
    
            nChannels = nOutChannels
            self.denseblock2 = self._make_dense(nChannels,growthRate,nNoneLinears,bottleneck)
            nChannels += nNoneLinears*growthRate
            nOutChannels = int(math.floor(nChannels*reduction))
            self.transition2 = Transition(nChannels, nOutChannels)
    
            nChannels = nOutChannels
            self.denseblock3 = self._make_dense(nChannels, growthRate, nNoneLinears, bottleneck)
            nChannels += nNoneLinears * growthRate
    
            self.bn1 = nn.BatchNorm2d(nChannels)
            self.fc = nn.Linear(nChannels,nClasses)
    
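            # Parameter initialization: He normal init (fan-out) for conv weights,
            # BN weight 1 and bias 0, zero bias for the linear layer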
            for m in self.modules():
                if isinstance(m,nn.Conv2d):
                    n = m.kernel_size[0]*m.kernel_size[1]*m.out_channels
                    m.weight.data.normal_(0,math.sqrt(2./n))
                elif isinstance(m,nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m,nn.Linear):
                    m.bias.data.zero_()
    
        def _make_dense(self,nChannels,growthRate,nDenseBlocks,bottleneck):
            layers = []
            for i in range(int(nDenseBlocks)):
                if bottleneck:
                    layers.append(Bottleneck(nChannels,growthRate))
                else:
                    layers.append(SingleLayer(nChannels,growthRate))
                nChannels+=growthRate   # each layer adds growthRate channels for the next one
            return nn.Sequential(*layers)
    
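        def forward(self, x):
            out = self.conv1(x)
            out = self.transition1(self.denseblock1(out))
            out = self.transition2(self.denseblock2(out))
            out = self.denseblock3(out)
            # 8x8 average pooling is global pooling for 32x32 (CIFAR) inputs;
            # flatten from dim 1 so a batch of size 1 keeps its batch dimension
            out = torch.flatten(F.avg_pool2d(F.relu(self.bn1(out)),8),1)
            out = F.log_softmax(self.fc(out),dim=1)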
            return out
  • Original post: https://www.cnblogs.com/houjun/p/10250546.html