zoukankan      html  css  js  c++  java
  • 【语义分割】Stacked Hourglass Networks 以及 PyTorch 实现

    Stacked Hourglass Networks(级联漏斗网络)

    姿态估计(Pose Estimation)是 CV 领域一个非常重要的方向,而级联漏斗网络的提出就是为了提升姿态估计的效果,但是其中的经典思想可以扩展到其他方向,比如目标识别方向,代表网络是 CornerNet(预测目标的左上角和右下角点,再进行组合画框)。

    CNN 之所以有效,是因为它能自动提取出对分类、检测和识别等任务有帮助的特征,并且随着网络层数的增加,所提取的特征逐渐变得抽象。以人脸识别为例,低层卷积网络能够提取出一些简单的特征,如轮廓;中间卷积网络能够提取出抽象一些的特征,如眼睛鼻子;较高层的卷积网络则能提取出更加抽象的特征,比如完整的人脸。这些将有助于我们理解级联漏斗模型(Stacked Hourglass Model,简称SHM)为什么有效。

    做姿态估计,需要预测身体不同的关节点,手臂这种线条简单的结构,可能在中间卷积网络更容易被识别;而面部这种线条复杂的结构,可能在高层卷积网络才更容易被识别。因此,如果我们只使用最后一层的 feature map,就会造成一些信息的丢失。SHN 的主要贡献——利用多尺度特征来识别姿态。

    Single Hourglass Network

    上图是单个漏斗网络的结构。该结构与全卷积网络和其它设计(以多尺度方式处理空间信息,并进行密集预测)紧密相连。然而漏斗网络与其它设计有什么不同呢?由图可以看出,其自底向上(从高分辨率到低分辨率)处理和自顶向下(从低分辨率到高分辨率)处理之间的容量分布(这里实在不知道怎么翻译。。。)更加对称。另外还有一点需要注意,在自顶向下处理过程中,使用的不是 unpooing(一种常见的上采样操作)或者 deconv layers(可称为去卷积层),而是采用nearest neighbor upsampling(最近邻上采样)和 skip connections。这些操作需要在源码中理解。

    Stacked Hourglass Networks

    StackedHourglass_1

    上图是单个漏斗网络后面的一些设计以及两个漏斗网络的连接细节

    块1 是上面介绍的单个沙漏网络,在它后面是一个 1$ imes(1 的全卷积网络,即块2;块2 后面分离出上下两个分支(块3 和块4):上分支(块3)依然是一个 1) imes$1 的全卷积网络,下分支(块4)为 Heat map(下面重点介绍)。块5 是对块4 进行 channal 上的扩增,以方便块3、块5 和 上个漏斗网络的输出进行合并,一起作为当前漏斗网络的输出,同时是下一个漏斗网络的输入。

    这里对 Heat map 进行解释:大部分姿态检测的最后一步是对 feature map 上的每个像素做概率预测,计算该像素是某个关节点的概率,而这里的 feature map 就是上面输出的 Heat map。使用它与真值进行误差计算。应用中,如果多个 Hourglass Module 组合在一起进行梯度下降,输出层的误差经过多层反向传播会大幅减小,也就是发生了梯度消失。因此,在整个网络中每个Hourglass Module 后面都会输出 Heat map 来计算损失。这种方法称为 中间监督(Intermediate Supervision),可以保证底层参数正常更新。

    之所以使用多个 Stack Hourglass,是为了重复自下而上和自上而下的推理机制,允许重新评估整个图像的初始估计和特征,实现这一过程的核心就是预测中间的 Heat map,并让中间 Heat map 参与 loss 计算。


    PyTorch 实现 Model

    1. 首先定义残差网络的基本模块:

      HgResBlock
      import torch.nn as nn
      
      
      class HgResBlock(nn.Module):
      
          def __init__(self, inplanes, outplanes, stride=1):
              super(HgResBlock, self).__init__()
      
              self.inplanes = inplanes
              self.outplanes = outplanes
              midplanes = outplanes // 2
      
              self.bn_1 = nn.BatchNorm2d(inplanes)
              self.conv_1 = nn.Conv2d(inplanes, midplanes, kernel_size=1, stride=stride)
              self.bn_2 = nn.BatchNorm2d(midplanes)
              self.conv_2 = nn.Conv2d(midplanes, midplanes, kernel_size=3, stride=1, padding=1)
              self.bn_3 = nn.BatchNorm2d(midplanes)
              self.conv_3 = nn.Conv2d(midplanes, outplanes, kernel_size=1, stride=1)
              self.relu = nn.ReLU(inplanes=True)
              if inplanes != outplanes:
                  self.conv_skip = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1)
      
          # Bottle neck
          def forward(self, x):
              residual = x
      
              out = self.bn_1(x)
              out = self.conv_1(out)
              out = self.relu(out)
      
              out = self.bn_2(out)
              out = self.conv_2(out)
              out = self.relu(out)
      
              out = self.bn_3(out)
              out = self.conv_3(out)
              out = self.relu(out)
      
              if self.inplanes != self.outplanes:
                  residual = self.conv_skip(residual)
              out += residual
      
              return out
      
    2. 定义单个的 Hourglass Module(注意这里用到了递归):

      HourglassNetwork
      import torch.nn as nn
      
      
      class Hourglass(nn.Module):
      
          def __init__(self, depth, nFeat, nModules, resBlocks):
              super(Hourglass, self).__init__()
      
              self.depth = depth
              self.nFeat = nFeat
              self.nModules = nModules
              self.resBlocks = resBlocks
      
              self.hg = self._make_hourglass()
              self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
              self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
      
          def _make_residual(self, n):
              return nn.Sequential(*[self.resBlocks(self.nFeat, self.nFeat) for _ in range(n)])
      
          def _make_hourglass(self):
              hg = []
      
              for i in range(self.depth):
                  res = [self._make_residual(self.nModules) for _ in range(3)]
                  if i == (self.depth - 1):
                      res.append(self._make_residual(self.nModules))      # extra one for the middle
                  hg.append(nn.ModuleList(res))
      
              return nn.ModuleList(hg)
      
          def _hourglass_forward(self, depth_id, x):
              up_1 = self.hg[depth_id][0](x)
              low_1 = self.downsample(x)
              low_1 = self.hg[depth_id][1](low_1)
      
              if depth_id == (self.depth - 1):
                  low_2 = self.hg[depth_id][3](low_1)
              else:
                  low_2 = self._hourglass_forward(depth_id+1, low_1)
      
              low_3 = self.hg[depth_id][2](low_2)
              up_2 = self.upsample(low_3)
      
              return up_1 + up_2
      
          def forward(self, x):
              return self._hourglass_forward(0, x)
      
    3. 定义 Stacked Hourglass Network:

      StackedHourglass_2
      import torch.nn as nn
      
      from Model.HgResBlock import HgResBlock
      from Model.SingleHourglass import Hourglass
      
      class HourglassNet(nn.Module):
      
          def __init__(self, nStacks, nModules, nFeat, nClasses, resBlock=HgResBlock, inplanes=3):
              super(HourglassNet, self).__init__()
      
              self.nStacks = nStacks
              self.nModules = nModules
              self.nFeat = nFeat
              self.nClasses = nClasses
              self.resBlock = resBlock
              self.inplanes = inplanes
      
              hg, res, fc, score, fc_, score_ = [], [], [], [], [], []
      
              for i in range(nStacks):
                  hg.append(Hourglass(depth=4, nFeat=nFeat, nModules=nModules, resBlocks=resBlock))
                  res.append(self._make_residual(nModules))
                  fc.append(self._make_fc(nFeat, nFeat))
                  score.append(nn.Conv2d(nFeat, nClasses, kernel_size=1))
                  if i < (nStacks - 1):
                      fc_.append(nn.Conv2d(nFeat, nFeat, kernel_size=1))
                      score_.append(nn.Conv2d(nClasses, nFeat, kernel_size=1))
      
              self.hg = nn.ModuleList(hg)
              self.res = nn.ModuleList(res)
              self.fc = nn.ModuleList(fc)
              self.score = nn.ModuleList(score)
              self.fc_ = nn.ModuleList(fc_)
              self.score_ = nn.ModuleList(score_)
      
          def _make_head(self):
              self.conv_1 = nn.Conv2d(self.inplanes, 64, kernel_size=7, stride=2, padding=3)
              self.bn_1 = nn.BatchNorm2d(64)
              self.relu = nn.ReLU(inplace=True)
      
              self.res_1 = self.resBlock(64, 128)
              self.pool = nn.MaxPool2d(2, 2)
              self.res_2 = self.resBlock(128, 128)
              self.res_3 = self.resBlock(128, self.nFeat)
      
          def _make_residual(self, n):
              return nn.Sequential(*[self.resBlock(self.nFeat, self.nFeat) for _ in range(n)])
      
          def _make_fc(self, inplanes, outplanes):
              return nn.Sequential(
                  nn.Conv2d(inplanes, outplanes, kernel_size=1),
                  nn.BatchNorm2d(outplanes),
                  nn.ReLU(True))
      
          def forward(self, x):
              # head
              x = self.conv_1(x)
              x = self.bn_1(x)
              x = self.relu(x)
      
              x = self.res_1(x)
              x = self.pool(x)
              x = self.res_2(x)
              x = self.res_3(x)
      
              out = []
      
              for i in range(self.nStacks):
                  y = self.hg[i](x)
                  y = self.res[i](y)
                  y = self.fc[i](y)
                  score = self.score[i](y)
                  out.append(score)
                  if i < (self.nStacks - 1):
                      fc_ = self.fc_[i](y)
                      score_ = self.score_[i](score)
                      x = x + fc_ + score_
      
              return out
      

    References:

    ​ [1] Stacked Hourglass Networks for Human Pose Estimation

    ​ [2] [hourglass pytorch 实现]
    (https://blog.csdn.net/github_36923418/article/details/81030883)

  • 相关阅读:
    判断当前时间为星期几
    springboot+mysql数据源切换
    表单上传图片
    po,vo,bo,dto,dao解释
    生成电脑的SSH key
    单例模式
    事物的特性和隔离级别
    springAOP自定义注解讲解
    Spring依赖注入(DI)的三种方式
    redis持久化
  • 原文地址:https://www.cnblogs.com/xxxxxxxxx/p/11651437.html
Copyright © 2011-2022 走看看