  • (Original) Human pose estimation: PyraNet

    Please credit the source when reposting:

    https://www.cnblogs.com/darkknightzh/p/12424767.html

    Paper:

    Learning Feature Pyramids for Human Pose Estimation

    https://arxiv.org/abs/1708.01101

    Third-party PyTorch code:

    https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation

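    1. Overall structure

    PyraNet replaces the residual modules of the hourglass with pyramid residual modules (the white boxes in the architecture figure), which learn features of the input image at different scales.

    For the hourglass itself, see https://www.cnblogs.com/darkknightzh/p/11486185.html. Note that the Hourglass in the reference code also uses PRM modules internally, rather than the original Hourglass.

    With stacked hourglass as background, this algorithm is fairly easy to follow.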

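    2. Pyramid residual module (PRM)

    The paper proposes four PRM (pyramid residual module) variants and finds that PRM-B performs best, as shown in the corresponding figure of the paper. In that figure, dashed lines denote identity mappings, and white dashed boxes mark positions with no upsampling or downsampling.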

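    3. Downsampling

    Ordinary pooling reduces the resolution too quickly and too coarsely, because its smallest downsampling factor is 2, so the paper does not use it. Instead, it downsamples with fractional max-pooling. The downsampling ratio of the c-th pyramid level is (the paper uses M=1, C=4):

    $s_{c}=2^{-M\frac{c}{C}},\quad c=0,\cdots ,C,\quad M\ge 1$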
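To make the formula concrete, the following minimal sketch evaluates the ratios for the paper's setting M=1, C=4. The helper name `pyramid_ratios` and the 64x64 example resolution are illustrative choices, and the output size is assumed to be the floor of input size times ratio.

```python
import math

def pyramid_ratios(M=1, C=4):
    """Downsampling ratio s_c = 2^(-M*c/C) for c = 0..C."""
    return [2.0 ** (-M * c / C) for c in range(C + 1)]

ratios = pyramid_ratios()          # M=1, C=4 as in the paper
# Assumed rounding: output size = floor(input size * ratio)
sizes = [math.floor(64 * r) for r in ratios]

for c, (r, s) in enumerate(zip(ratios, sizes)):
    print(f"level {c}: ratio {r:.4f}, 64x64 feature map -> {s}x{s}")
```

The ratios run from 1 (no downsampling) to 0.5, so all pyramid levels sit between the original resolution and a single 2x pooling step.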

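    4. Training and testing

    Training is similar to other pose estimation algorithms: the network predicts heatmaps, and the loss is the squared error between the ground-truth heatmaps and the predicted heatmaps:

    $L=\frac{1}{2}\sum_{n=1}^{N}\sum_{k=1}^{K}\left\| \mathbf{S}_{k}-\hat{\mathbf{S}}_{k} \right\|^{2}$

    where N is the number of samples and K is the number of keypoints (that is, the number of heatmaps).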
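As an illustration of this loss, here is a minimal sketch that uses NumPy rather than PyTorch for brevity. The function name `heatmap_loss` and the array shapes (N, K, H, W) are assumptions made for the example; training code in practice may average instead of sum.

```python
import numpy as np

def heatmap_loss(pred, target):
    """0.5 * sum over samples n and keypoints k of ||S_k - S_hat_k||^2.

    pred, target: arrays of shape (N, K, H, W).
    """
    diff = pred - target
    return 0.5 * float(np.sum(diff ** 2))

# Toy data: 2 samples, 16 keypoints, 64x64 heatmaps
target = np.zeros((2, 16, 64, 64))
pred = np.zeros_like(target)
pred[0, 0, 10, 20] = 1.0           # one wrong pixel with error 1
loss = heatmap_loss(pred, target)
print(loss)
```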

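    At test time, the position of the maximum score in the heatmaps of the last hourglass is taken as the keypoint. Because this is a top-down pose estimation algorithm, the image fed to the network contains only one person, so the position of the maximum score is the corresponding keypoint:

    $\hat{\mathbf{z}}_{k}=\underset{\mathbf{p}}{\arg\max}\,\hat{\mathbf{S}}_{k}(\mathbf{p}),\quad k=1,\cdots ,K$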
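The argmax decoding can be sketched as follows, again with NumPy for brevity. The function name `decode_keypoints` and the (x, y) output order are illustrative assumptions, not taken from the reference code.

```python
import numpy as np

def decode_keypoints(heatmaps):
    """Return per-keypoint (x, y) of the maximum score and the score itself.

    heatmaps: array of shape (K, H, W), one heatmap per keypoint.
    """
    K, H, W = heatmaps.shape
    flat = heatmaps.reshape(K, -1)
    idx = flat.argmax(axis=1)                 # flat index of the maximum
    ys, xs = np.unravel_index(idx, (H, W))    # back to row (y) and column (x)
    scores = flat[np.arange(K), idx]
    return np.stack([xs, ys], axis=1), scores

heatmaps = np.zeros((16, 64, 64))
heatmaps[3, 40, 12] = 0.9                     # keypoint 3 peaks at x=12, y=40
coords, scores = decode_keypoints(heatmaps)
print(coords[3], scores[3])
```

The coordinates are in heatmap resolution; mapping them back to the input image requires the scale factor between heatmap and input.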

    5. Code

    PyraNet is defined as follows:

    import torch.nn as nn
    # M is the module file of the reference repository that defines
    # BnReluConv, ResidualPyramid and Residual (its import path is not shown in this post).

    class PyraNet(nn.Module):
        """docstring for PyraNet"""
        def __init__(self, nChannels=256, nStack=4, nModules=2, numReductions=4, baseWidth=6, cardinality=30, nJoints=16, inputRes=256):
            super(PyraNet, self).__init__()
            self.nChannels = nChannels
            self.nStack = nStack
            self.nModules = nModules
            self.numReductions = numReductions
            self.baseWidth = baseWidth
            self.cardinality = cardinality
            self.inputRes = inputRes
            self.nJoints = nJoints

            self.start = M.BnReluConv(3, 64, kernelSize = 7, stride = 2, padding = 3)   # BN + ReLU + conv

            # ResidualPyramid: two branches (1x1 conv + 3x3 conv; 1x1 conv + sum of multi-scale features + 1x1 conv)
            # are summed and expanded by a 1x1 conv. The skip path is the identity when the input and output
            # channels match, otherwise a 1x1 conv; its output is added to the result.
            self.res1 = M.ResidualPyramid(64, 128, self.inputRes//2, self.baseWidth, self.cardinality, 0)
            self.mp = nn.MaxPool2d(2, 2)
            self.res2 = M.ResidualPyramid(128, 128, self.inputRes//4, self.baseWidth, self.cardinality)
            self.res3 = M.ResidualPyramid(128, self.nChannels, self.inputRes//4, self.baseWidth, self.cardinality)

            _hourglass, _Residual, _lin1, _chantojoints, _lin2, _jointstochan = [],[],[],[],[],[]

            for _ in range(self.nStack):   # number of stacked hourglasses
                _hourglass.append(PyraNetHourGlass(self.nChannels, self.numReductions, self.nModules, self.inputRes//4, self.baseWidth, self.cardinality))
                _ResidualModules = []
                for _ in range(self.nModules):
                    # input and output channels are equal, so the skip path has no conv: only 3 x (BN + ReLU + conv)
                    _ResidualModules.append(M.Residual(self.nChannels, self.nChannels))
                _ResidualModules = nn.Sequential(*_ResidualModules)
                _Residual.append(_ResidualModules)
                _lin1.append(M.BnReluConv(self.nChannels, self.nChannels))        # BN + ReLU + conv
                _chantojoints.append(nn.Conv2d(self.nChannels, self.nJoints,1))   # 1x1 conv: features -> joint heatmaps
                _lin2.append(nn.Conv2d(self.nChannels, self.nChannels,1))         # 1x1 conv: features -> features
                _jointstochan.append(nn.Conv2d(self.nJoints,self.nChannels,1))    # 1x1 conv: joint heatmaps -> features

            self.hourglass = nn.ModuleList(_hourglass)
            self.Residual = nn.ModuleList(_Residual)
            self.lin1 = nn.ModuleList(_lin1)
            self.chantojoints = nn.ModuleList(_chantojoints)
            self.lin2 = nn.ModuleList(_lin2)
            self.jointstochan = nn.ModuleList(_jointstochan)

        def forward(self, x):
            x = self.start(x)
            x = self.res1(x)
            x = self.mp(x)
            x = self.res2(x)
            x = self.res3(x)
            out = []   # heatmaps predicted by each stack

            for i in range(self.nStack):
                x1 = self.hourglass[i](x)
                x1 = self.Residual[i](x1)
                x1 = self.lin1[i](x1)
                out.append(self.chantojoints[i](x1))
                x1 = self.lin2[i](x1)
                x = x + x1 + self.jointstochan[i](out[i])     # sum the features as input to the next stack

            return out
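The spatial resolution at which the heatmaps are produced follows from the stem of the network. The sketch below traces it with the standard convolution output-size formula; the helper `conv_out` is illustrative, and it assumes that the residual modules and hourglasses preserve resolution, as the listing above suggests.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a conv or pooling layer along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

input_res = 256
after_start = conv_out(input_res, kernel=7, stride=2, padding=3)    # self.start
after_pool = conv_out(after_start, kernel=2, stride=2, padding=0)   # self.mp

print(after_start, after_pool)
```

With the default `inputRes=256`, the hourglasses therefore operate at `inputRes//4`, and each entry of the returned list has `nJoints` heatmaps at that resolution.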

    ResidualPyramid is defined as follows:

    class ResidualPyramid(nn.Module):
        """docstring for ResidualPyramid"""
        # The PyraConvBlock output is added to the skip path, which is the identity when the
        # input and output channels match and a 1x1 conv otherwise.
        def __init__(self, inChannels, outChannels, inputRes, baseWidth, cardinality, type = 1):
            super(ResidualPyramid, self).__init__()
            self.inChannels = inChannels
            self.outChannels = outChannels
            self.inputRes = inputRes
            self.baseWidth = baseWidth
            self.cardinality = cardinality
            self.type = type
            # PyraConvBlock: two branches (1x1 conv + 3x3 conv; 1x1 conv + sum of multi-scale features + 1x1 conv),
            # summed and then expanded by a 1x1 conv
            self.cb = PyraConvBlock(self.inChannels, self.outChannels, self.inputRes, self.baseWidth, self.cardinality, self.type)
            # no conv when input and output channels are equal, otherwise a 1x1 conv
            self.skip = SkipLayer(self.inChannels, self.outChannels)

        def forward(self, x):
            return self.cb(x) + self.skip(x)

    PyraConvBlock is as follows:

    class PyraConvBlock(nn.Module):
        """docstring for PyraConvBlock"""
        # Two branches (1x1 conv + 3x3 conv; 1x1 conv + sum of multi-scale features + 1x1 conv),
        # summed and then expanded by a 1x1 conv
        def __init__(self, inChannels, outChannels, inputRes, baseWidth, cardinality, type = 1):
            super(PyraConvBlock, self).__init__()
            self.inChannels = inChannels
            self.outChannels = outChannels
            self.inputRes = inputRes
            self.baseWidth = baseWidth
            self.cardinality = cardinality
            self.outChannelsby2 = outChannels//2
            self.D = self.outChannels // self.baseWidth
            self.branch1 = nn.Sequential(   # first branch: 1x1 conv + 3x3 conv
                    BnReluConv(self.inChannels, self.outChannelsby2, 1, 1, 0),           # BN + ReLU + 1x1 conv
                    BnReluConv(self.outChannelsby2, self.outChannelsby2, 3, 1, 1)        # BN + ReLU + 3x3 conv
                )
            self.branch2 = nn.Sequential(   # second branch: 1x1 conv + feature pyramid + 1x1 conv
                    BnReluConv(self.inChannels, self.D, 1, 1, 0),                        # BN + ReLU + 1x1 conv
                    BnReluPyra(self.D, self.cardinality, self.inputRes),                 # BN + ReLU + sum of multi-scale features
                    BnReluConv(self.D, self.outChannelsby2, 1, 1, 0)                     # BN + ReLU + 1x1 conv
                )
            self.afteradd = BnReluConv(self.outChannelsby2, self.outChannels, 1, 1, 0)   # BN + ReLU + 1x1 conv

        def forward(self, x):
            x = self.branch2(x) + self.branch1(x)   # sum the features of the two branches
            x = self.afteradd(x)                    # 1x1 conv expands the channels to outChannels
            return x

    BnReluPyra is as follows:

    class BnReluPyra(nn.Module):
        """docstring for BnReluPyra"""     # BN + ReLU + sum of multi-scale features
        def __init__(self, D, cardinality, inputRes):
            super(BnReluPyra, self).__init__()
            self.D = D
            self.cardinality = cardinality
            self.inputRes = inputRes
            self.bn = nn.BatchNorm2d(self.D)
            self.relu = nn.ReLU()
            self.pyra = Pyramid(self.D, self.cardinality, self.inputRes)     # sums the features of all scales

        def forward(self, x):
            x = self.bn(x)
            x = self.relu(x)
            x = self.pyra(x)
            return x

    Pyramid is as follows:

    import torch

    class Pyramid(nn.Module):
        """docstring for Pyramid"""     # sums the features of all scales
        def __init__(self, D, cardinality, inputRes):
            super(Pyramid, self).__init__()
            self.D = D
            self.cardinality = cardinality     # C in Eq. 3 of the paper: the number of pyramid levels
            self.inputRes = inputRes
            self.scale = 2**(-1/self.cardinality)   # downsampling ratio of the first pyramid level
            _scales = []
            for card in range(self.cardinality):
                temp = nn.Sequential(    # downsample + 3x3 conv + upsample
                        # the ratio of each level is the first-level ratio raised to the power (card + 1)
                        nn.FractionalMaxPool2d(2, output_ratio = self.scale**(card + 1)),
                        nn.Conv2d(self.D, self.D, 3, 1, 1),
                        nn.Upsample(size = self.inputRes)#, mode='bilinear')   # upsample back to the input resolution
                    )
                _scales.append(temp)
            self.scales = nn.ModuleList(_scales)

        def forward(self, x):
            #print(x.shape, self.inputRes)
            out = torch.zeros_like(x)         # zero tensor with the same shape as the input
            for card in range(self.cardinality):
                out += self.scales[card](x)    # accumulate the features of every scale
            return out
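The ratios built in `Pyramid.__init__` can be checked against the downsampling formula of section 3. The sketch below assumes M=1; the helper names `code_ratios` and `formula_ratios` are illustrative. It suggests that the code covers the levels c=1..C of the formula, and that the default `cardinality=30` in the listings simply yields a finer pyramid than the paper's C=4.

```python
import math

def code_ratios(cardinality):
    """Ratios as computed in Pyramid.__init__: scale ** (card + 1)."""
    scale = 2 ** (-1 / cardinality)
    return [scale ** (card + 1) for card in range(cardinality)]

def formula_ratios(C, M=1):
    """Ratios from s_c = 2^(-M*c/C) for c = 1..C."""
    return [2 ** (-M * c / C) for c in range(1, C + 1)]

# Compare both computations for the paper's C and the code's default cardinality
matches = {
    C: all(math.isclose(a, b) for a, b in zip(code_ratios(C), formula_ratios(C)))
    for C in (4, 30)
}
print(matches)
```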

    PyraNetHourGlass is as follows:

    class PyraNetHourGlass(nn.Module):
        """docstring for PyraNetHourGlass"""
        def __init__(self, nChannels=256, numReductions=4, nModules=2, inputRes=256, baseWidth=6, cardinality=30, poolKernel=(2,2), poolStride=(2,2), upSampleKernel=2):
            super(PyraNetHourGlass, self).__init__()
            self.numReductions = numReductions
            self.nModules = nModules
            self.nChannels = nChannels
            self.poolKernel = poolKernel
            self.poolStride = poolStride
            self.upSampleKernel = upSampleKernel

            self.inputRes = inputRes
            self.baseWidth = baseWidth
            self.cardinality = cardinality

            """ For the skip connection, a residual module (or sequence of residual modules)  """
            # ResidualPyramid: two-branch pyramid block plus skip path (identity or 1x1 conv)
            # Residual: 3 x (BN + ReLU + conv) plus skip path (identity when channels are equal)
            Residualskip = M.ResidualPyramid if numReductions > 1 else M.Residual
            Residualmain = M.ResidualPyramid if numReductions > 2 else M.Residual
            _skip = []
            for _ in range(self.nModules):  # pyramid or plain residual, depending on numReductions
                _skip.append(Residualskip(self.nChannels, self.nChannels, self.inputRes, self.baseWidth, self.cardinality))
            self.skip = nn.Sequential(*_skip)

            """ First pooling to go to smaller dimension then pass input through
            Residual Module or sequence of Modules then  and subsequent cases:
                either pass through Hourglass of numReductions-1 or pass through Residual Module or sequence of Modules """
            self.mp = nn.MaxPool2d(self.poolKernel, self.poolStride)

            _afterpool = []
            for _ in range(self.nModules):   # pyramid or plain residual, depending on numReductions
                _afterpool.append(Residualmain(self.nChannels, self.nChannels, self.inputRes//2, self.baseWidth, self.cardinality))
            self.afterpool = nn.Sequential(*_afterpool)

            if (numReductions > 1):     # recursive construction of the inner hourglass
                self.hg = PyraNetHourGlass(self.nChannels, self.numReductions-1, self.nModules, self.inputRes//2, self.baseWidth,
                                           self.cardinality, self.poolKernel, self.poolStride, self.upSampleKernel)
            else:
                _num1res = []
                for _ in range(self.nModules):    # pyramid or plain residual, depending on numReductions
                    _num1res.append(Residualmain(self.nChannels,self.nChannels, self.inputRes//2, self.baseWidth, self.cardinality))
                self.num1res = nn.Sequential(*_num1res)  # doesnt seem that important ?

            """ Now another Residual Module or sequence of Residual Modules """
            _lowres = []
            for _ in range(self.nModules):    # pyramid or plain residual, depending on numReductions
                _lowres.append(Residualmain(self.nChannels,self.nChannels, self.inputRes//2, self.baseWidth, self.cardinality))
            self.lowres = nn.Sequential(*_lowres)

            """ Upsampling Layer (Can we change this??????) As per Newell's paper upsampling recommended  """
            self.up = nn.Upsample(scale_factor = self.upSampleKernel)     # enlarge height and width (upsampling)

        def forward(self, x):
            out1 = x
            out1 = self.skip(out1)             # skip path at the input resolution
            out2 = x
            out2 = self.mp(out2)               # max-pool: halve the resolution
            out2 = self.afterpool(out2)        # residual modules after pooling
            if self.numReductions>1:
                out2 = self.hg(out2)           # recursive call into the inner hourglass
            else:
                out2 = self.num1res(out2)      # innermost level: residual modules only
            out2 = self.lowres(out2)           # residual modules at the low resolution
            out2 = self.up(out2)               # upsample back to the input resolution

            return out2 + out1                 # sum of the main path and the skip path
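To see how the recursion unfolds, the following sketch lists, for each hourglass level, the resolution of the skip path and of the pooled main path together with the residual type chosen by the `numReductions` conditions above. The function `hourglass_plan` is a hypothetical helper for illustration and builds no layers.

```python
def hourglass_plan(num_reductions, input_res):
    """List (level, skip_res, skip_type, main_res, main_type) per recursion level."""
    plan = []
    while num_reductions >= 1:
        # Same conditions as Residualskip / Residualmain in PyraNetHourGlass
        skip_type = "ResidualPyramid" if num_reductions > 1 else "Residual"
        main_type = "ResidualPyramid" if num_reductions > 2 else "Residual"
        plan.append((num_reductions, input_res, skip_type, input_res // 2, main_type))
        num_reductions -= 1      # the inner hourglass has one reduction fewer
        input_res //= 2          # and runs at half the resolution
    return plan

# Defaults used by PyraNet: numReductions=4, hourglass input at inputRes//4 = 64
plan = hourglass_plan(4, 64)
for level in plan:
    print(level)
```

Under these defaults, the pyramid modules are used only at the larger resolutions, while the two innermost levels of the main path fall back to plain `Residual` modules.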

    Residual is as follows:

    class Residual(nn.Module):
        """docstring for Residual"""
        # When input and output channels are equal, the skip path has no conv and only the 3 x (BN + ReLU + conv)
        # block is learned; otherwise the input goes through a 1x1 conv before being added to the block output.
        def __init__(self, inChannels, outChannels, inputRes=None, baseWidth=None, cardinality=None, type=None):
            super(Residual, self).__init__()
            self.inChannels = inChannels
            self.outChannels = outChannels
            # 3 x (BN + ReLU + conv): the first reduces the channels, the second keeps them, the third expands them
            self.cb = ConvBlock(self.inChannels, self.outChannels)
            # no conv when input and output channels are equal, otherwise a 1x1 conv
            self.skip = SkipLayer(self.inChannels, self.outChannels)

        def forward(self, x):
            return self.cb(x) + self.skip(x)