转载请注明出处:
https://www.cnblogs.com/darkknightzh/p/11486185.html
论文:
https://arxiv.org/abs/1603.06937
官方torch代码(没具体看):
https://github.com/princeton-vl/pose-hg-demo
第三方pytorch代码(位于models/StackedHourGlass.py):
https://github.com/Naman-ntc/Pytorch-Human-Pose-Estimation
1. 简介
该论文利用多尺度特征来识别姿态,如下图所示,每个子网络称为hourglass Network,是一个沙漏型的结构,多个这种结构堆叠起来,称作stacked hourglass。堆叠的方式,方便每个模块在整个图像上重新估计姿态和特征。如下图所示,输入图像通过全卷积网络fcn后,得到特征,而后通过多个堆叠的hourglass,得到最终的热图。
Hourglass如下图所示,其中每个方块均为一个残差模块(结构见第5节)。
Hourglass采用了中间监督(Intermediate Supervision):每个hourglass都会输出一组热图(图中蓝色部分)。训练阶段,分别计算每组热图与真实热图之间的MSE损失,并将这些损失求和作为总损失;推断阶段,只使用最后一个hourglass输出的热图。
2. stacked hourglass
堆叠hourglass结构如下图所示(nChannels=256,nStack=2,nModules=2,numReductions=4, nJoints=17):
代码如下:

class StackedHourGlass(nn.Module):
    """Stem network followed by ``nStack`` hourglass stages with a heatmap head each.

    Args:
        nChannels: Channel width used by every hourglass stage.
        nStack: Number of stacked hourglass stages.
        nModules: Number of residual modules per residual sequence.
        numReductions: Recursion depth forwarded to each ``Hourglass``.
        nJoints: Number of output channels of each heatmap head.

    NOTE(review): ``M`` is presumably the module that defines ``BnReluConv``
    and ``Residual``; it is not visible in this block.
    """

    def __init__(self, nChannels, nStack, nModules, numReductions, nJoints):
        super(StackedHourGlass, self).__init__()
        self.nChannels = nChannels
        self.nStack = nStack
        self.nModules = nModules
        self.numReductions = numReductions
        self.nJoints = nJoints

        # Stem: BN+ReLU+7x7 stride-2 conv on a 3-channel input, then a
        # residual (64 -> 128), a 2x2 max-pool and two more residuals
        # ending at nChannels.
        self.start = M.BnReluConv(3, 64, kernelSize=7, stride=2, padding=3)
        self.res1 = M.Residual(64, 128)
        self.mp = nn.MaxPool2d(2, 2)
        self.res2 = M.Residual(128, 128)
        self.res3 = M.Residual(128, self.nChannels)

        hourglasses = []
        residual_stacks = []
        feature_convs = []
        heatmap_heads = []
        feature_remaps = []
        heatmap_remaps = []

        # Submodules are created stage by stage, in the same order as the
        # heads are used in forward().
        for _ in range(self.nStack):
            hourglasses.append(
                Hourglass(self.nChannels, self.numReductions, self.nModules)
            )
            residual_stacks.append(
                nn.Sequential(
                    *[
                        M.Residual(self.nChannels, self.nChannels)
                        for _ in range(self.nModules)
                    ]
                )
            )
            # BN+ReLU+conv with BnReluConv's default kernel arguments.
            feature_convs.append(M.BnReluConv(self.nChannels, self.nChannels))
            # 1x1 conv: features -> per-joint heatmaps.
            heatmap_heads.append(nn.Conv2d(self.nChannels, self.nJoints, 1))
            # 1x1 conv: features -> features (same width).
            feature_remaps.append(nn.Conv2d(self.nChannels, self.nChannels, 1))
            # 1x1 conv: heatmaps -> features, so predictions can be fed back.
            heatmap_remaps.append(nn.Conv2d(self.nJoints, self.nChannels, 1))

        self.hourglass = nn.ModuleList(hourglasses)
        self.Residual = nn.ModuleList(residual_stacks)
        self.lin1 = nn.ModuleList(feature_convs)
        self.chantojoints = nn.ModuleList(heatmap_heads)
        self.lin2 = nn.ModuleList(feature_remaps)
        self.jointstochan = nn.ModuleList(heatmap_remaps)

    def forward(self, x):
        """Return a list with one heatmap tensor per stage, in stage order."""
        x = self.res1(self.start(x))
        x = self.res3(self.res2(self.mp(x)))

        heatmaps = []
        for stage in range(self.nStack):
            features = self.hourglass[stage](x)
            features = self.Residual[stage](features)
            features = self.lin1[stage](features)
            heatmaps.append(self.chantojoints[stage](features))
            remapped = self.lin2[stage](features)
            # Next stage input: previous input + remapped features +
            # this stage's heatmaps projected back to nChannels.
            x = x + remapped + self.jointstochan[stage](heatmaps[stage])

        return heatmaps
3. hourglass
hourglass在numReductions>1时,递归调用自己,结构如下:
代码如下:

class Hourglass(nn.Module):
    """Recursive hourglass: a skip branch plus a pooled, processed, upsampled branch.

    Args:
        nChannels: Channel width of every residual module in the block.
        numReductions: Remaining recursion depth. While it is > 1 the low
            resolution branch contains another ``Hourglass`` built with
            ``numReductions - 1``; otherwise it contains a residual sequence.
        nModules: Number of residual modules in each residual sequence.
        poolKernel: Kernel size passed to ``nn.MaxPool2d``.
        poolStride: Stride passed to ``nn.MaxPool2d``.
        upSampleKernel: Stored on the instance and forwarded to nested
            hourglasses. ``myUpsample()`` is constructed without arguments,
            so this value does not currently reach the upsampling layer.

    NOTE(review): the final addition needs both branches to have the same
    shape, so the pooling and upsampling factors presumably have to match
    and the spatial size has to survive the pool/upsample round trip --
    confirm against callers.
    """

    def __init__(self, nChannels=256, numReductions=4, nModules=2,
                 poolKernel=(2, 2), poolStride=(2, 2), upSampleKernel=2):
        super(Hourglass, self).__init__()
        self.numReductions = numReductions
        self.nModules = nModules
        self.nChannels = nChannels
        self.poolKernel = poolKernel
        self.poolStride = poolStride
        self.upSampleKernel = upSampleKernel

        # Skip connection: a residual sequence at the input resolution.
        self.skip = self._residual_stack()

        # Low-resolution branch: pool, then a residual sequence ...
        self.mp = nn.MaxPool2d(self.poolKernel, self.poolStride)
        self.afterpool = self._residual_stack()

        # ... then either a nested hourglass or one more residual sequence.
        if numReductions > 1:
            # Forward upSampleKernel as well, so nested levels keep the
            # caller's configuration instead of falling back to the default.
            self.hg = Hourglass(
                self.nChannels,
                self.numReductions - 1,
                self.nModules,
                self.poolKernel,
                self.poolStride,
                self.upSampleKernel,
            )
        else:
            self.num1res = self._residual_stack()

        # Another residual sequence before upsampling.
        self.lowres = self._residual_stack()

        # Upsampling layer; the original listing kept
        # nn.Upsample(scale_factor=self.upSampleKernel) as a commented-out
        # alternative.
        self.up = myUpsample()

    def _residual_stack(self):
        """Build ``nModules`` channel-preserving residual modules in sequence."""
        return nn.Sequential(
            *[
                M.Residual(self.nChannels, self.nChannels)
                for _ in range(self.nModules)
            ]
        )

    def forward(self, x):
        """Return the sum of the skip branch and the upsampled low-res branch."""
        skip_branch = self.skip(x)

        low = self.afterpool(self.mp(x))
        if self.numReductions > 1:
            low = self.hg(low)
        else:
            low = self.num1res(low)
        low = self.up(self.lowres(low))

        return low + skip_branch
4. 上采样myUpsample
上采样代码如下:

class myUpsample(nn.Module):
    """Upsample the last two dimensions of a 4-D tensor by repeating each element.

    Args:
        scale_factor: Integer number of times every element is repeated along
            both the third and the fourth dimension. Defaults to 2, which
            matches the previously hard-coded behaviour.
    """

    def __init__(self, scale_factor=2):
        super(myUpsample, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return ``x`` with dimensions 2 and 3 enlarged ``scale_factor`` times."""
        factor = self.scale_factor
        # Insert singleton axes after dims 2 and 3, expand them to `factor`,
        # then fold them back into their neighbouring dims. Each input
        # element thereby becomes a factor x factor block of equal values.
        tiled = x[:, :, :, None, :, None].expand(-1, -1, -1, factor, -1, factor)
        return tiled.reshape(
            x.size(0), x.size(1), x.size(2) * factor, x.size(3) * factor
        )
其中x为形状(N, C, H, W)的张量,x[:, :, :, None, :, None]的形状为(N, C, H, 1, W, 1),expand之后变成(N, C, H, 2, W, 2),最终reshape成(N, C, 2H, 2W)。这样每个像素在水平和垂直方向各被复制2次,变成4个取值相同的像素,即完成了最近邻上采样。
5. 残差模块
残差模块结构如下:
代码如下:

class Residual(nn.Module):
    """Residual module: output is ``ConvBlock(x) + SkipLayer(x)``.

    Args:
        inChannels: Number of channels of the input.
        outChannels: Number of channels of the output. Both branches are
            constructed with the same ``(inChannels, outChannels)`` pair.
    """

    def __init__(self, inChannels, outChannels):
        super(Residual, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        # Main branch: three BN+ReLU+conv stages.
        self.cb = ConvBlock(inChannels, outChannels)
        # Shortcut branch: identity when the widths match, else a 1x1 conv.
        self.skip = SkipLayer(inChannels, outChannels)

    def forward(self, x):
        """Return the sum of the main branch and the shortcut branch."""
        # Sum the branches directly; starting from a zero accumulator would
        # only add an extra elementwise op and a temporary tensor.
        return self.cb(x) + self.skip(x)
其中skiplayer代码如下:

class SkipLayer(nn.Module):
    """Shortcut branch: identity for equal widths, 1x1 convolution otherwise.

    Args:
        inChannels: Number of channels of the input.
        outChannels: Number of channels of the output.
    """

    def __init__(self, inChannels, outChannels):
        super(SkipLayer, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        same_width = self.inChannels == self.outChannels
        # No projection is needed when the channel counts already agree.
        self.conv = (
            None if same_width else nn.Conv2d(self.inChannels, self.outChannels, 1)
        )

    def forward(self, x):
        """Return ``x`` itself, or its 1x1-conv projection if one was built."""
        if self.conv is None:
            return x
        return self.conv(x)
6. conv

class BnReluConv(nn.Module):
    """Batch normalisation, then ReLU, then a 2-D convolution.

    Args:
        inChannels: Channels of the input (also the BatchNorm feature count).
        outChannels: Channels produced by the convolution.
        kernelSize: Convolution kernel size (default 1).
        stride: Convolution stride (default 1).
        padding: Convolution padding (default 0).
    """

    def __init__(self, inChannels, outChannels, kernelSize=1, stride=1, padding=0):
        super(BnReluConv, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        self.kernelSize = kernelSize
        self.stride = stride
        self.padding = padding

        # Normalisation acts on the input, so it is sized by inChannels.
        self.bn = nn.BatchNorm2d(self.inChannels)
        self.conv = nn.Conv2d(
            in_channels=self.inChannels,
            out_channels=self.outChannels,
            kernel_size=self.kernelSize,
            stride=self.stride,
            padding=self.padding,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply bn -> relu -> conv to ``x`` and return the result."""
        normalised = self.bn(x)
        activated = self.relu(normalised)
        return self.conv(activated)
7. ConvBlock

class ConvBlock(nn.Module):
    """Three chained BN+ReLU+conv stages with a half-width middle.

    The stages are: 1x1 conv from ``inChannels`` to ``outChannels // 2``,
    3x3 conv (padding 1) keeping ``outChannels // 2``, and 1x1 conv from
    ``outChannels // 2`` to ``outChannels``.

    Args:
        inChannels: Number of channels of the input.
        outChannels: Number of channels of the output.
    """

    def __init__(self, inChannels, outChannels):
        super(ConvBlock, self).__init__()
        self.inChannels = inChannels
        self.outChannels = outChannels
        half_width = outChannels // 2
        self.outChannelsby2 = half_width

        # Arguments after the channel counts are kernel size, stride, padding.
        self.cbr1 = BnReluConv(self.inChannels, half_width, 1, 1, 0)
        self.cbr2 = BnReluConv(half_width, half_width, 3, 1, 1)
        self.cbr3 = BnReluConv(half_width, self.outChannels, 1, 1, 0)

    def forward(self, x):
        """Run ``x`` through cbr1, cbr2 and cbr3 in that order."""
        return self.cbr3(self.cbr2(self.cbr1(x)))