zoukankan      html  css  js  c++  java
  • SA-Net: Shuffle Attention for Deep Convolutional Neural Networks

    论文:https://arxiv.org/pdf/2102.00240.pdf

    代码:https://github.com/wofmanaf/SA-Net

    当前的 CNN 中的 attention 机制主要包括:channel attention 和 spatial attention,当前一些方法(GCNet 、CBAM 等)通常将二者集成,容易产生 converging difficulty 和 heavy computation burden 的问题。尽管 ECANet 和 SGE 提出了一些优化方案,但没有充分利用 channel 和 spatial 之间的关系。因此,作者提出一个问题 “ Can one fuse different attention modules in a lighter but more efficient way? ”

    为解决这个问题,作者提出了 shuffle attention,整体框架如下图所示。可以看出首先将输入的特征分为(g)组,然后每一组的特征进行split,分成两个分支,分别计算 channel attention 和 spatial attention,两种 attention 都使用全连接 + sigmoid 的方法计算。接着,两个分支的结果拼接到一起,然后合并,得到和输入尺寸一致的 feature map。 最后,用一个 shuffle 层进行处理。

    代码如下。 可以看出,在最后的 shuffle 部分,是直接分为两个组,然后置换进行组间交互。

    class sa_layer(nn.Module):
        def __init__(self, channel, groups=64):
            super(sa_layer, self).__init__()
            self.groups = groups
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.cweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
            self.cbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
            self.sweight = Parameter(torch.zeros(1, channel // (2 * groups), 1, 1))
            self.sbias = Parameter(torch.ones(1, channel // (2 * groups), 1, 1))
            self.sigmoid = nn.Sigmoid()
            self.gn = nn.GroupNorm(channel // (2 * groups), channel // (2 * groups))
    
        def forward(self, x):
            b, c, h, w = x.shape
    		# 将各个组与 n 合并在一维
            x = x.reshape(b * self.groups, -1, h, w)
            # 每组特征拆成 2 组,方便 2 分支处理
            x_0, x_1 = x.chunk(2, dim=1)
    
            # channel attention
            xn = self.avg_pool(x_0)
            xn = self.cweight * xn + self.cbias
            xn = x_0 * self.sigmoid(xn)
    
            # spatial attention
            xs = self.gn(x_1)
            xs = self.sweight * xs + self.sbias
            xs = x_1 * self.sigmoid(xs)
    
            # 沿 channel 方向合并
            out = torch.cat([xn, xs], dim=1)
            # 恢复与输入一致的 feature map 尺寸
            out = out.reshape(b, -1, h, w)
    		# 分为两个组进行 channel shuffle,后面有代码解析
            out = self.channel_shuffle(out, 2)
            return out
    

    Channel shuffle 的代码如下:

    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        # 因为要分组,先 reshape 成5个维度
        x = x.reshape(b, groups, -1, h, w)
        # 把 groups 和 channel 维度替换
        x = x.permute(0, 2, 1, 3, 4)
        # 恢复成输入的形状,实现 channel shuffle
        x = x.reshape(b, -1, h, w)
        return x
    

    实验部分可以参照原作者的论文,这里不多介绍。

  • 相关阅读:
    线程同步
    快捷键之Sublime
    快捷键之Idea
    快捷键之Chrome
    20155219付颖卓《网络对抗》逆向及Bof基础
    20155219 《嵌入式基础》
    20155219 《信息安全系统设计基础》课程总结
    20155219 《信息安全系统设计基础》第十四周学习总结
    20155219 《信息安全系统设计基础》第十三周学习总结
    2017-2018-1 20155219《信息安全系统设计基础》 实验五 通讯协议设计
  • 原文地址:https://www.cnblogs.com/gaopursuit/p/14403103.html
Copyright © 2011-2022 走看看