zoukankan      html  css  js  c++  java
  • Toward fast and accurate human pose estimation via soft-gated skip connections

    Toward fast and accurate human pose estimation via soft-gated skip connections

    一. 论文简介

    设计小的block和feature map的融合方式,提人体姿态估计高计算效率和精度

    主要做的贡献如下(可能之前有人已提出):

    1. block部分使用soft-gate
    2. feature融合对比,选择最佳方式

    二. 模块详解

    2.1 Soft-Gate模块

    • 简单的说明就是下图展示所示,给每个channel加上了attention机制,这种做法其实SE已经完成了,就是中间的步骤有点小区别而已
    • 以下结合SE的权重,给出参考代码,未尝试运行,整体思路是这样
    #FIXME This code is demo for paper, maybe cant run directly, please modify some bug when you use.
    class ConvBNReLU(nn.Sequential):
        '''
            #FIXME Only for 3*3 and 1*1 convolution by dilation equal 1 or 2
        '''
        def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, dilation=1, norm_layer=None):
            padding = (kernel_size - 1) // 2
            if dilation==2 and kernel_size==3 and stride==1:
                padding=2
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d
            super(ConvBNReLU, self).__init__(
                nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding=padding, groups=groups, dilation=dilation, bias=False),
                norm_layer(out_planes),
                nn.ReLU6(inplace=True)
            )
    
    class SoftGateBlock(nn.Module):
        def __init__(self, inp, gate=SqueezeExcite):
            super(SoftGateBlock, self).__init__()
            assert inp%4 == 0
            self.layers = self.build_layer(inp, gate)
    
        def forward(self, x):
            alpha = self.layers[0](x)
            branch_y1 = self.layers[1](x)
            branch_y2 = self.layers[1](branch_y1)
            branch_y3 = self.layers[1](branch_y2)
            branch = torch.cat([branch_y1,branch_y2,branch_y3],dim=1)
            y = alpha + branch #TODO please add dimension (torch.unsqueeze) to align dim if error
            return y
    
        def build_layer(self, chs, gate):
            layers = []
            layers.append(gate(chs))
            for i in range(3):
                layers.append(ConvBNReLU(chs, chs/2))
                chs = chs/2
            return nn.ModuleList(layers)
    
    # SE module that attetion mode
    class SqueezeExcite(nn.Module):
        def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                     act_layer=nn.ReLU, gate_fn=hard_sigmoid, divisor=4, **_):
            super(SqueezeExcite, self).__init__()
            self.gate_fn = gate_fn
            reduced_chs = _make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
            self.act1 = act_layer(inplace=True)
            self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
    
        def forward(self, x):
            x_se = self.avg_pool(x)
            x_se = self.conv_reduce(x_se)
            x_se = self.act1(x_se)
            x_se = self.conv_expand(x_se)
            x = x * self.gate_fn(x_se)
            return x    
    

    论文给的细节不完整,比如:

    1. 如何得到获得这个权重 (alpha) ?方法有很多,让复现的人一个一个尝试?
      • 这里demo以SE模块为例子
    2. 这里的block只给出了channel不变情况,那下采样和上采样的操作呢?
      • 要么使用resnet原来block,要么自己设计。建议参考论文自己设计,因为目的是提升精度,使用前者没意义

    2.2 Feature融合

    • 这部分很简单,就是对比几个连接操作,最后选择下图 ((b)) 的内容,至于为什么?文中给了对比效果表格。
  • 相关阅读:
    Elasticsearch Network Settings
    Spring Application Event Example
    hibernate persist update 方法没有正常工作(不保存数据,不更新数据)
    快速自检电脑是否被黑客入侵过(Linux版)
    快速自检电脑是否被黑客入侵过(Windows版)
    WEB中的敏感文件泄漏
    Nginx日志分析
    关系型数据库设计小结
    软件开发的一些"心法"
    关于DNS,你应该知道这些
  • 原文地址:https://www.cnblogs.com/wjy-lulu/p/13628612.html
Copyright © 2011-2022 走看看