(Original) Stacked Hourglass Networks

Please credit the source when reposting:

https://www.cnblogs.com/darkknightzh/p/11486185.html

Paper:

https://arxiv.org/abs/1603.06937

Official Torch code (not examined in detail):

https://github.com/princeton-vl/pose-hg-demo

Third-party PyTorch code (in models/StackedHourGlass.py):

https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation

1. Introduction

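The paper uses multi-scale features to estimate pose, as shown in the figure below. Each sub-network is called an hourglass network because of its hourglass-shaped structure, and stacking several of them gives the stacked hourglass network. Stacking lets each module re-estimate the pose and the features over the whole image, building on the estimate of the previous module. As the figure below shows, the input image first passes through a fully convolutional front end (in the code: a 7×7 stride-2 convolution, residual modules and a max pool) to obtain features. These features then pass through the stacked hourglasses to produce the final heatmaps.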

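The hourglass module is shown in the figure below. Each box in it is a residual module, whose structure is shown two figures further down.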

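The network uses intermediate supervision: every hourglass outputs its own heatmaps (blue in the figure). During training, the MSE loss between each hourglass's heatmaps and the ground-truth heatmaps is computed, and these losses are summed to give the total loss. At inference time, only the heatmaps of the last hourglass are used.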
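The loss described above can be illustrated with a minimal, framework-free sketch. It assumes heatmaps are plain nested lists of shape nJoints × H × W, a hypothetical representation chosen only so the example runs without PyTorch. It is not taken from the third-party repository, whose training code may differ in details. With PyTorch tensors the same idea is usually written as `sum(nn.functional.mse_loss(o, target) for o in out)`.

```python
from typing import List

Heatmap = List[List[float]]  # one H x W heatmap (hypothetical plain-list representation)


def mse(pred: List[Heatmap], target: List[Heatmap]) -> float:
    """Mean squared error over all joints and pixels."""
    total, count = 0.0, 0
    for p_map, t_map in zip(pred, target):
        for p_row, t_row in zip(p_map, t_map):
            for p, t in zip(p_row, t_row):
                total += (p - t) ** 2
                count += 1
    return total / count


def intermediate_supervision_loss(stack_outputs: List[List[Heatmap]],
                                  target: List[Heatmap]) -> float:
    """Training: every hourglass's heatmaps are compared with the same
    ground truth, and the per-stack MSE losses are summed."""
    return sum(mse(out, target) for out in stack_outputs)


def inference_heatmaps(stack_outputs: List[List[Heatmap]]) -> List[Heatmap]:
    """Inference: only the last hourglass's heatmaps are used."""
    return stack_outputs[-1]
```

For example, with two stacks where the first is off by 0.5 at one of four pixels and the second is exact, the summed loss is 0.25 / 4 = 0.0625.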

2. Stacked hourglass

The stacked hourglass structure is shown in the figure below (nChannels=256, nStack=2, nModules=2, numReductions=4, nJoints=17):

The code is as follows:

```python
import torch.nn as nn
# M: the module that defines BnReluConv and Residual (sections 5-7);
# Hourglass is defined in section 3.

class StackedHourGlass(nn.Module):
    """Stacked hourglass network: a downsampling stem followed by nStack hourglasses."""
    def __init__(self, nChannels, nStack, nModules, numReductions, nJoints):
        super(StackedHourGlass, self).__init__()
        self.nChannels = nChannels
        self.nStack = nStack
        self.nModules = nModules
        self.numReductions = numReductions
        self.nJoints = nJoints

        self.start = M.BnReluConv(3, 64, kernelSize = 7, stride = 2, padding = 3)  # BN+ReLU+conv, halves H and W

        self.res1 = M.Residual(64, 128)   # in != out: 1x1 conv(x) + 3*(BN+ReLU+conv)
        self.mp = nn.MaxPool2d(2, 2)      # halves H and W again
        self.res2 = M.Residual(128, 128)  # in == out: x + 3*(BN+ReLU+conv)
        self.res3 = M.Residual(128, self.nChannels)  # x + 3*(BN+ReLU+conv) if nChannels == 128, else 1x1 conv(x) + 3*(BN+ReLU+conv)

        _hourglass, _Residual, _lin1, _chantojoints, _lin2, _jointstochan = [],[],[],[],[],[]

        for _ in range(self.nStack):  # one set of layers per stacked hourglass
            _hourglass.append(Hourglass(self.nChannels, self.numReductions, self.nModules))
            _ResidualModules = []
            for _ in range(self.nModules):
                _ResidualModules.append(M.Residual(self.nChannels, self.nChannels))   # in == out: x + 3*(BN+ReLU+conv)
            _ResidualModules = nn.Sequential(*_ResidualModules)
            _Residual.append(_ResidualModules)   # nModules residual modules in sequence
            _lin1.append(M.BnReluConv(self.nChannels, self.nChannels))       # BN+ReLU+1x1 conv
            _chantojoints.append(nn.Conv2d(self.nChannels, self.nJoints,1))  # 1x1 conv: features -> nJoints heatmaps
            _lin2.append(nn.Conv2d(self.nChannels, self.nChannels,1))        # 1x1 conv, channel count unchanged
            _jointstochan.append(nn.Conv2d(self.nJoints,self.nChannels,1))   # 1x1 conv: heatmaps -> nChannels features

        self.hourglass = nn.ModuleList(_hourglass)
        self.Residual = nn.ModuleList(_Residual)
        self.lin1 = nn.ModuleList(_lin1)
        self.chantojoints = nn.ModuleList(_chantojoints)
        self.lin2 = nn.ModuleList(_lin2)
        self.jointstochan = nn.ModuleList(_jointstochan)

    def forward(self, x):
        x = self.start(x)
        x = self.res1(x)
        x = self.mp(x)
        x = self.res2(x)
        x = self.res3(x)
        out = []

        for i in range(self.nStack):
            x1 = self.hourglass[i](x)
            x1 = self.Residual[i](x1)
            x1 = self.lin1[i](x1)
            out.append(self.chantojoints[i](x1))        # heatmaps of hourglass i
            x1 = self.lin2[i](x1)
            x = x + x1 + self.jointstochan[i](out[i])   # sum of features: input to the next stack

        return out   # list of nStack heatmap tensors, one per hourglass
```
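The following plain-Python sketch traces (channels, H, W) through StackedHourGlass for the configuration above, showing what the stem does to the resolution. It only does shape arithmetic. `conv_out` applies the standard output-size formula of nn.Conv2d / nn.MaxPool2d (dilation 1), and `stacked_hourglass_shapes` is a hypothetical helper, not part of the repository. The sketch also checks an assumption the code makes implicitly. Each hourglass pools numReductions times and then upsamples back, so `out2 + out1` only has matching shapes when the feature-map size is divisible by 2**numReductions.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of nn.Conv2d / nn.MaxPool2d with dilation 1 (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1


def stacked_hourglass_shapes(h: int, w: int, n_channels: int = 256, n_stack: int = 2,
                             n_joints: int = 17, num_reductions: int = 4
                             ) -> List[Tuple[str, Tuple[int, int, int]]]:
    """Trace (C, H, W) through the stem and the stacks (hypothetical helper)."""
    trace = []
    h, w = conv_out(h, 7, 2, 3), conv_out(w, 7, 2, 3)  # self.start: 7x7, stride 2, padding 3
    trace.append(("start", (64, h, w)))
    trace.append(("res1", (128, h, w)))                 # Residual keeps H and W
    h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)         # self.mp: MaxPool2d(2, 2)
    trace.append(("mp", (128, h, w)))
    trace.append(("res2", (128, h, w)))
    trace.append(("res3", (n_channels, h, w)))
    # Each Hourglass halves H and W num_reductions times and upsamples back,
    # so its skip sum only lines up if both are divisible by 2**num_reductions.
    if h % 2 ** num_reductions or w % 2 ** num_reductions:
        raise ValueError("feature map size not divisible by 2**num_reductions")
    for i in range(n_stack):
        # hourglass, Residual, lin1 keep the shape; chantojoints maps to nJoints
        trace.append((f"out[{i}]", (n_joints, h, w)))
    return trace
```

For a 256×256 input the stem produces 64×64 feature maps, so each of the nStack outputs is a 17×64×64 heatmap per image.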

3. Hourglass

When numReductions > 1, the hourglass calls itself recursively. Its structure is as follows:

The code is as follows:

```python
import torch.nn as nn
# M: the module that defines Residual (section 5); myUpsample is defined in section 4.

class Hourglass(nn.Module):
    """One hourglass module; builds a smaller Hourglass recursively while numReductions > 1."""
    def __init__(self, nChannels = 256, numReductions = 4, nModules = 2, poolKernel = (2,2), poolStride = (2,2), upSampleKernel = 2):
        super(Hourglass, self).__init__()
        self.numReductions = numReductions
        self.nModules = nModules
        self.nChannels = nChannels
        self.poolKernel = poolKernel
        self.poolStride = poolStride
        self.upSampleKernel = upSampleKernel

        # Skip connection: a residual module (or a sequence of residual modules)
        _skip = []
        for _ in range(self.nModules):
            _skip.append(M.Residual(self.nChannels, self.nChannels))  # in == out: x + 3*(BN+ReLU+conv)
        self.skip = nn.Sequential(*_skip)

        # Pool first to go to a smaller resolution, then pass the input through a residual
        # module (or sequence of modules); after that, either recurse into an Hourglass with
        # numReductions-1 or, at the innermost level, pass through more residual modules.
        self.mp = nn.MaxPool2d(self.poolKernel, self.poolStride)

        _afterpool = []
        for _ in range(self.nModules):
            _afterpool.append(M.Residual(self.nChannels, self.nChannels))  # in == out: x + 3*(BN+ReLU+conv)
        self.afterpool = nn.Sequential(*_afterpool)

        if (numReductions > 1):
            # recursive construction; upSampleKernel is passed on as well
            self.hg = Hourglass(self.nChannels, self.numReductions-1, self.nModules, self.poolKernel, self.poolStride, self.upSampleKernel)
        else:
            _num1res = []
            for _ in range(self.nModules):
                _num1res.append(M.Residual(self.nChannels,self.nChannels))  # in == out: x + 3*(BN+ReLU+conv)
            self.num1res = nn.Sequential(*_num1res)  # innermost level ("doesn't seem that important?" per the original author)

        # Then another residual module (or sequence of residual modules)
        _lowres = []
        for _ in range(self.nModules):
            _lowres.append(M.Residual(self.nChannels,self.nChannels))   # in == out: x + 3*(BN+ReLU+conv)
        self.lowres = nn.Sequential(*_lowres)

        # Upsampling layer (could this be changed?); Newell's paper recommends upsampling here.
        self.up = myUpsample()  # alternative: nn.Upsample(scale_factor = self.upSampleKernel); doubles H and W

    def forward(self, x):
        out1 = x
        out1 = self.skip(out1)          # skip branch at the current resolution: x + 3*(BN+ReLU+conv)
        out2 = x
        out2 = self.mp(out2)            # downsample: halve H and W
        out2 = self.afterpool(out2)     # x + 3*(BN+ReLU+conv)
        if self.numReductions>1:
            out2 = self.hg(out2)        # recursive call into the next, smaller hourglass
        else:
            out2 = self.num1res(out2)   # innermost level: x + 3*(BN+ReLU+conv)
        out2 = self.lowres(out2)        # x + 3*(BN+ReLU+conv)
        out2 = self.up(out2)            # upsample: double H and W

        return out2 + out1              # element-wise sum of both branches
```
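The recursion can be summarized with a short plain-Python sketch. Both helpers are hypothetical and not from the repository. `hourglass_resolutions` lists the spatial sizes visited on the way down, one halving per `self.mp`. `hourglass_residual_count` counts the Residual modules built by `__init__`: skip, afterpool and lowres at every level, plus num1res at the innermost level.

```python
def hourglass_resolutions(size: int, num_reductions: int) -> list:
    """Spatial sizes visited on the way down through nested Hourglass modules."""
    sizes = [size]
    for _ in range(num_reductions):
        size //= 2  # each level's MaxPool2d(2, 2) halves H and W
        sizes.append(size)
    return sizes


def hourglass_residual_count(num_reductions: int, n_modules: int) -> int:
    """Residual modules created by Hourglass.__init__, following its recursion."""
    per_level = 3 * n_modules  # skip + afterpool + lowres
    if num_reductions > 1:
        return per_level + hourglass_residual_count(num_reductions - 1, n_modules)
    return per_level + n_modules  # innermost level also builds num1res
```

With nModules=2 and numReductions=4 on a 64×64 input, the hourglass goes down to 4×4 and contains 3·2·4 + 2 = 26 residual modules.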

4. Upsampling: myUpsample

The upsampling code is as follows:

```python
import torch.nn as nn

class myUpsample(nn.Module):
    def __init__(self):
        super(myUpsample, self).__init__()

    def forward(self, x):   # double H and W (nearest-neighbour upsampling)
        return x[:, :, :, None, :, None].expand(-1, -1, -1, 2, -1, 2).reshape(x.size(0), x.size(1), x.size(2)*2, x.size(3)*2)
```

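Here x is an (N)(C)(H)(W) tensor. x[:, :, :, None, :, None] is an (N)(C)(H)(1)(W)(1) tensor; expand turns it into (N)(C)(H)(2)(W)(2), and the final reshape turns that into (N)(C)(2H)(2W). Each pixel is therefore duplicated twice horizontally and twice vertically into 4 pixels with the same value, that is, out[n, c, 2i+a, 2j+b] = x[n, c, i, j] for a, b ∈ {0, 1}. This completes the upsampling.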
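The index mapping can be checked without PyTorch. The sketch below uses a hypothetical helper that works on a single H×W channel stored as nested lists and computes out[r][c] = img[r // 2][c // 2], the same mapping as myUpsample. In PyTorch, nn.Upsample(scale_factor=2) with its default mode='nearest' is the alternative left commented out in Hourglass, and for this factor of 2 it should give the same result.

```python
def upsample2x_nearest(img):
    """2x nearest-neighbour upsampling of one H x W channel (list of rows)."""
    h, w = len(img), len(img[0])
    # out[2i + a][2j + b] = img[i][j] for a, b in {0, 1}: the same mapping that
    # myUpsample's (H, 2, W, 2) -> (2H, 2W) reshape produces
    return [[img[r // 2][c // 2] for c in range(2 * w)] for r in range(2 * h)]


print(upsample2x_nearest([[1, 2], [3, 4]]))
# [[1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4]]
```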

5. Residual module

The structure of the residual module is shown below:

The code is as follows:

```python
import torch.nn as nn
# ConvBlock (section 7) and SkipLayer (below) are defined in the same module.

class Residual(nn.Module):
    """Residual module: if in == out channels, output = x + 3*(BN+ReLU+conv);
    otherwise output = 1x1 conv(x) + 3*(BN+ReLU+conv)."""
    def __init__(self, inChannels, outChannels):
        super(Residual, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.cb = ConvBlock(inChannels, outChannels)      # 3*(BN+ReLU+conv): maps to outChannels/2, keeps it, then maps to outChannels
        self.skip = SkipLayer(inChannels, outChannels)    # identity if in == out channels, else a 1x1 conv

    def forward(self, x):
        return self.cb(x) + self.skip(x)  # sum of the conv branch and the skip branch
```

The SkipLayer code is as follows:

```python
import torch.nn as nn

class SkipLayer(nn.Module):
    """Skip branch: identity if input and output channel counts are equal, otherwise a 1x1 conv."""
    def __init__(self, inChannels, outChannels):
        super(SkipLayer, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        if (self.inChannels == self.outChannels):
            self.conv = None
        else:
            self.conv = nn.Conv2d(self.inChannels, self.outChannels, 1)  # 1x1 conv to match the channel count

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        return x
```
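SkipLayer needs the 1×1 convolution because Residual.forward adds cb(x) and skip(x) element-wise, so both must have outChannels channels. The sketch below uses a hypothetical helper to count the parameters the skip branch adds, assuming nn.Conv2d's default bias=True.

```python
def skip_layer_params(in_channels: int, out_channels: int) -> int:
    """Parameters of SkipLayer: identity (no parameters) when channel counts
    match, otherwise a 1x1 conv with weight (out x in x 1 x 1) plus bias (out)."""
    if in_channels == out_channels:
        return 0
    return in_channels * out_channels + out_channels
```

In StackedHourGlass, res1 (64 → 128) gets a 1×1 projection with 8,320 parameters. The nChannels → nChannels residuals inside the hourglasses use the free identity path.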

6. Conv (BnReluConv)

```python
import torch.nn as nn

class BnReluConv(nn.Module):
    """BN + ReLU + conv, in pre-activation order (normalize and activate before the convolution)."""
    def __init__(self, inChannels, outChannels, kernelSize = 1, stride = 1, padding = 0):
        super(BnReluConv, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.kernelSize = kernelSize
        self.stride = stride
        self.padding = padding

        self.bn = nn.BatchNorm2d(self.inChannels)   # normalizes the input channels
        self.conv = nn.Conv2d(self.inChannels, self.outChannels, self.kernelSize, self.stride, self.padding)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        x = self.conv(x)
        return x
```
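BnReluConv uses pre-activation order: batch normalization and ReLU are applied to the input before the convolution. The per-channel sketch below shows the first two steps in inference mode. In that mode nn.BatchNorm2d normalizes with its running mean and variance, using the default eps=1e-5; during training it would use the statistics of the current batch instead. The helper name and the plain-list input are assumptions for illustration.

```python
import math


def bn_relu_inference(values, running_mean, running_var,
                      weight=1.0, bias=0.0, eps=1e-5):
    """First two steps of BnReluConv for one channel at inference time:
    BatchNorm with running statistics, then ReLU (the conv would follow)."""
    scale = weight / math.sqrt(running_var + eps)
    normalized = [(v - running_mean) * scale + bias for v in values]
    return [max(0.0, v) for v in normalized]  # ReLU sets negative values to 0
```

For example, with running mean 1 and variance 4, the inputs [2.0, 0.0, -1.0] become approximately [0.5, 0.0, 0.0] before reaching the convolution.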

7. ConvBlock

```python
import torch.nn as nn

class ConvBlock(nn.Module):
    """3*(BN+ReLU+conv) bottleneck: 1x1 to outChannels/2, 3x3 at outChannels/2, 1x1 back to outChannels."""
    def __init__(self, inChannels, outChannels):
        super(ConvBlock, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.outChannelsby2 = outChannels//2

        self.cbr1 = BnReluConv(self.inChannels, self.outChannelsby2, 1, 1, 0)        # 1x1 conv: inChannels -> outChannels/2
        self.cbr2 = BnReluConv(self.outChannelsby2, self.outChannelsby2, 3, 1, 1)    # 3x3 conv, padding 1: keeps H, W and channels
        self.cbr3 = BnReluConv(self.outChannelsby2, self.outChannels, 1, 1, 0)       # 1x1 conv: outChannels/2 -> outChannels

    def forward(self, x):
        x = self.cbr1(x)
        x = self.cbr2(x)
        x = self.cbr3(x)
        return x
```
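ConvBlock is a bottleneck: the expensive 3×3 convolution only sees outChannels/2 channels. The plain-Python sketch below uses hypothetical helpers to count its learnable parameters. It assumes nn.Conv2d's default bias=True and that nn.BatchNorm2d contributes one weight and one bias per channel (its running statistics are buffers, not parameters).

```python
def conv2d_params(in_c: int, out_c: int, k: int) -> int:
    return in_c * out_c * k * k + out_c  # weight + bias


def batchnorm2d_params(c: int) -> int:
    return 2 * c  # learnable weight and bias per channel


def conv_block_params(in_c: int, out_c: int) -> int:
    """Parameters of ConvBlock(in_c, out_c): three BN+ReLU+conv stages."""
    mid = out_c // 2  # outChannelsby2, the bottleneck width
    return (batchnorm2d_params(in_c) + conv2d_params(in_c, mid, 1)    # cbr1
            + batchnorm2d_params(mid) + conv2d_params(mid, mid, 3)    # cbr2
            + batchnorm2d_params(mid) + conv2d_params(mid, out_c, 1)) # cbr3
```

ConvBlock(256, 256) comes to 214,528 parameters. For comparison, a single 3×3 convolution from 256 to 256 channels alone has 256·256·9 + 256 = 590,080.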