zoukankan      html  css  js  c++  java
  • 姿态估计openpose_pytorch_code浅析(待补充)

    接上文,经过了openpose的原理简单的解析,这一节我们主要进行code的解析。

    CODE解析
    我们主要参考的代码是https://github.com/tensorboy/pytorch_Realtime_Multi-Person_Pose_Estimation,代码写的很好,我们主要看的是demo/picture_demo.py
    首先我们看下效果,

    作图表示输入的图片,酷酷的四字弟弟,右图是出来的关键点,我们根据demo中的单张图片的前向过程讲解下模型的inference阶段,在简单的过一下训练的过程。

    而在inference的阶段中,我们主要看这几个关键的函数,我们把这几个函数扒出来单独介绍下。其中相对重要的几个函数,我们主要进行了标红,其中权重文件,在这个git里面有相关的下载连接。


    咱们主要看这这几部分
    1. model = get_model('vgg19')
    表示在上面所述的网络的示意图中,那个F使用的是vgg19,提取到的feature maps。此函数在lib/network/rtpose_vgg.py之中。

      1 """CPM Pytorch Implementation"""
      2 
      3 from collections import OrderedDict
      4 
      5 import torch
      6 import torch.nn as nn
      7 import torch.nn.functional as F
      8 import torch.utils.data as data
      9 import torch.utils.model_zoo as model_zoo
     10 from torch.autograd import Variable
     11 from torch.nn import init
     12 
     13 def make_stages(cfg_dict):
     14     """Builds CPM stages from a dictionary
     15     Args:
     16         cfg_dict: a dictionary
     17     """
     18     layers = []
     19     for i in range(len(cfg_dict) - 1):
     20         one_ = cfg_dict[i]
     21         for k, v in one_.items():
     22             if 'pool' in k:
     23                 layers += [nn.MaxPool2d(kernel_size=v[0], stride=v[1],
     24                                         padding=v[2])]
     25             else:
     26                 conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
     27                                    kernel_size=v[2], stride=v[3],
     28                                    padding=v[4])
     29                 layers += [conv2d, nn.ReLU(inplace=True)]
     30     one_ = list(cfg_dict[-1].keys())
     31     k = one_[0]
     32     v = cfg_dict[-1][k]
     33     conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
     34                        kernel_size=v[2], stride=v[3], padding=v[4])
     35     layers += [conv2d]
     36     return nn.Sequential(*layers)
     37 
     38 
     39 def make_vgg19_block(block):
     40     """Builds a vgg19 block from a dictionary
     41     Args:
     42         block: a dictionary
     43     """
     44     layers = []
     45     for i in range(len(block)):
     46         one_ = block[i]
     47         for k, v in one_.items():
     48             if 'pool' in k:
     49                 layers += [nn.MaxPool2d(kernel_size=v[0], stride=v[1],
     50                                         padding=v[2])]
     51             else:
     52                 conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1],
     53                                    kernel_size=v[2], stride=v[3],
     54                                    padding=v[4])
     55                 layers += [conv2d, nn.ReLU(inplace=True)]
     56     return nn.Sequential(*layers)
     57 
     58 
     59 
     60 def get_model(trunk='vgg19'):
     61     """Creates the whole CPM model
     62     Args:
     63         trunk: string, 'vgg19' or 'mobilenet'
     64     Returns: Module, the defined model
     65     """
     66     blocks = {}
     67     # block0 is the preprocessing stage
     68     if trunk == 'vgg19':
     69         block0 = [{'conv1_1': [3, 64, 3, 1, 1]},
     70                   {'conv1_2': [64, 64, 3, 1, 1]},
     71                   {'pool1_stage1': [2, 2, 0]},
     72                   {'conv2_1': [64, 128, 3, 1, 1]},
     73                   {'conv2_2': [128, 128, 3, 1, 1]},
     74                   {'pool2_stage1': [2, 2, 0]},
     75                   {'conv3_1': [128, 256, 3, 1, 1]},
     76                   {'conv3_2': [256, 256, 3, 1, 1]},
     77                   {'conv3_3': [256, 256, 3, 1, 1]},
     78                   {'conv3_4': [256, 256, 3, 1, 1]},
     79                   {'pool3_stage1': [2, 2, 0]},
     80                   {'conv4_1': [256, 512, 3, 1, 1]},
     81                   {'conv4_2': [512, 512, 3, 1, 1]},
     82                   {'conv4_3_CPM': [512, 256, 3, 1, 1]},
     83                   {'conv4_4_CPM': [256, 128, 3, 1, 1]}]
     84 
     85     elif trunk == 'mobilenet':
     86         block0 = [{'conv_bn': [3, 32, 2]},  # out: 3, 32, 184, 184
     87                   {'conv_dw1': [32, 64, 1]},  # out: 32, 64, 184, 184
     88                   {'conv_dw2': [64, 128, 2]},  # out: 64, 128, 92, 92
     89                   {'conv_dw3': [128, 128, 1]},  # out: 128, 256, 92, 92
     90                   {'conv_dw4': [128, 256, 2]},  # out: 256, 256, 46, 46
     91                   {'conv4_3_CPM': [256, 256, 1, 3, 1]},
     92                   {'conv4_4_CPM': [256, 128, 1, 3, 1]}]
     93 
     94     # Stage 1
     95     blocks['block1_1'] = [{'conv5_1_CPM_L1': [128, 128, 3, 1, 1]},
     96                           {'conv5_2_CPM_L1': [128, 128, 3, 1, 1]},
     97                           {'conv5_3_CPM_L1': [128, 128, 3, 1, 1]},
     98                           {'conv5_4_CPM_L1': [128, 512, 1, 1, 0]},
     99                           {'conv5_5_CPM_L1': [512, 38, 1, 1, 0]}]
    100 
    101     blocks['block1_2'] = [{'conv5_1_CPM_L2': [128, 128, 3, 1, 1]},
    102                           {'conv5_2_CPM_L2': [128, 128, 3, 1, 1]},
    103                           {'conv5_3_CPM_L2': [128, 128, 3, 1, 1]},
    104                           {'conv5_4_CPM_L2': [128, 512, 1, 1, 0]},
    105                           {'conv5_5_CPM_L2': [512, 19, 1, 1, 0]}]
    106 
    107     # Stages 2 - 6
    108     for i in range(2, 7):
    109         blocks['block%d_1' % i] = [
    110             {'Mconv1_stage%d_L1' % i: [185, 128, 7, 1, 3]},
    111             {'Mconv2_stage%d_L1' % i: [128, 128, 7, 1, 3]},
    112             {'Mconv3_stage%d_L1' % i: [128, 128, 7, 1, 3]},
    113             {'Mconv4_stage%d_L1' % i: [128, 128, 7, 1, 3]},
    114             {'Mconv5_stage%d_L1' % i: [128, 128, 7, 1, 3]},
    115             {'Mconv6_stage%d_L1' % i: [128, 128, 1, 1, 0]},
    116             {'Mconv7_stage%d_L1' % i: [128, 38, 1, 1, 0]}
    117         ]
    118 
    119         blocks['block%d_2' % i] = [
    120             {'Mconv1_stage%d_L2' % i: [185, 128, 7, 1, 3]},
    121             {'Mconv2_stage%d_L2' % i: [128, 128, 7, 1, 3]},
    122             {'Mconv3_stage%d_L2' % i: [128, 128, 7, 1, 3]},
    123             {'Mconv4_stage%d_L2' % i: [128, 128, 7, 1, 3]},
    124             {'Mconv5_stage%d_L2' % i: [128, 128, 7, 1, 3]},
    125             {'Mconv6_stage%d_L2' % i: [128, 128, 1, 1, 0]},
    126             {'Mconv7_stage%d_L2' % i: [128, 19, 1, 1, 0]}
    127         ]
    128 
    129     models = {}
    130 
    131     if trunk == 'vgg19':
    132         print("Bulding VGG19")
    133         models['block0'] = make_vgg19_block(block0)
    134 
    135     for k, v in blocks.items():
    136         models[k] = make_stages(list(v))
    137 
    138     class rtpose_model(nn.Module):
    139         def __init__(self, model_dict):
    140             super(rtpose_model, self).__init__()
    141             self.model0 = model_dict['block0']
    142             self.model1_1 = model_dict['block1_1']
    143             self.model2_1 = model_dict['block2_1']
    144             self.model3_1 = model_dict['block3_1']
    145             self.model4_1 = model_dict['block4_1']
    146             self.model5_1 = model_dict['block5_1']
    147             self.model6_1 = model_dict['block6_1']
    148 
    149             self.model1_2 = model_dict['block1_2']
    150             self.model2_2 = model_dict['block2_2']
    151             self.model3_2 = model_dict['block3_2']
    152             self.model4_2 = model_dict['block4_2']
    153             self.model5_2 = model_dict['block5_2']
    154             self.model6_2 = model_dict['block6_2']
    155 
    156             self._initialize_weights_norm()
    157 
    158         def forward(self, x):
    159 
    160             saved_for_loss = []
    161             out1 = self.model0(x)
    162 
    163             out1_1 = self.model1_1(out1)
    164             out1_2 = self.model1_2(out1)
    165             out2 = torch.cat([out1_1, out1_2, out1], 1)
    166             saved_for_loss.append(out1_1)
    167             saved_for_loss.append(out1_2)
    168 
    169             out2_1 = self.model2_1(out2)
    170             out2_2 = self.model2_2(out2)
    171             out3 = torch.cat([out2_1, out2_2, out1], 1)
    172             saved_for_loss.append(out2_1)
    173             saved_for_loss.append(out2_2)
    174 
    175             out3_1 = self.model3_1(out3)
    176             out3_2 = self.model3_2(out3)
    177             out4 = torch.cat([out3_1, out3_2, out1], 1)
    178             saved_for_loss.append(out3_1)
    179             saved_for_loss.append(out3_2)
    180 
    181             out4_1 = self.model4_1(out4)
    182             out4_2 = self.model4_2(out4)
    183             out5 = torch.cat([out4_1, out4_2, out1], 1)
    184             saved_for_loss.append(out4_1)
    185             saved_for_loss.append(out4_2)
    186 
    187             out5_1 = self.model5_1(out5)
    188             out5_2 = self.model5_2(out5)
    189             out6 = torch.cat([out5_1, out5_2, out1], 1)
    190             saved_for_loss.append(out5_1)
    191             saved_for_loss.append(out5_2)
    192 
    193             out6_1 = self.model6_1(out6)
    194             out6_2 = self.model6_2(out6)
    195             saved_for_loss.append(out6_1)
    196             saved_for_loss.append(out6_2)
    197             #其中out6_1 为38个feature maps
    198             #其中out6_2 为19个feature maps
    199             #saved_for_loss表示需要计算loss的层,对于训练的时候有用
    200             return (out6_1, out6_2), saved_for_loss
    201 
    202         def _initialize_weights_norm(self):
    203 
    204             for m in self.modules():
    205                 if isinstance(m, nn.Conv2d):
    206                     init.normal_(m.weight, std=0.01)
    207                     if m.bias is not None:  # mobilenet conv2d doesn't add bias
    208                         init.constant_(m.bias, 0.0)
    209 
    210             # last layer of these block don't have Relu
    211             init.normal_(self.model1_1[8].weight, std=0.01)
    212             init.normal_(self.model1_2[8].weight, std=0.01)
    213 
    214             init.normal_(self.model2_1[12].weight, std=0.01)
    215             init.normal_(self.model3_1[12].weight, std=0.01)
    216             init.normal_(self.model4_1[12].weight, std=0.01)
    217             init.normal_(self.model5_1[12].weight, std=0.01)
    218             init.normal_(self.model6_1[12].weight, std=0.01)
    219 
    220             init.normal_(self.model2_2[12].weight, std=0.01)
    221             init.normal_(self.model3_2[12].weight, std=0.01)
    222             init.normal_(self.model4_2[12].weight, std=0.01)
    223             init.normal_(self.model5_2[12].weight, std=0.01)
    224             init.normal_(self.model6_2[12].weight, std=0.01)
    225 
    226     model = rtpose_model(models)
    227     return model
    228 
    229 
    230 """Load pretrained model on Imagenet
    231 :param model, the PyTorch nn.Module which will train.
    232 :param model_path, the directory which load the pretrained model, will download one if not have.
    233 :param trunk, the feature extractor network of model.               
    234 """
    235 
    236 
    237 def use_vgg(model):
    238 
    239     url = 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'
    240     vgg_state_dict = model_zoo.load_url(url)
    241     vgg_keys = vgg_state_dict.keys()
    242 
    243     # load weights of vgg
    244     weights_load = {}
    245     # weight+bias,weight+bias.....(repeat 10 times)
    246     for i in range(20):
    247         weights_load[list(model.state_dict().keys())[i]
    248                      ] = vgg_state_dict[list(vgg_keys)[i]]
    249 
    250     state = model.state_dict()
    251     state.update(weights_load)
    252     model.load_state_dict(state)
    253     print('load imagenet pretrained model')
    254    


    2. get_outputs函数中,包括了模型的搭建,以及前向的工作,得到了paf以及heatmap图。咱们需要特别看下这个函数。

    def get_outputs(img, model, preprocess):
        """Computes the averaged heatmap and paf for the given image
        :param multiplier:
        :param origImg: numpy array, the image being processed
        :param model: pytorch model
        :returns: numpy arrays, the averaged paf and heatmap
        """
        inp_size = cfg.DATASET.IMAGE_SIZE
        #其中inp_size为368
        # padding
        #其中的DOWNSAMPLE为默认值是8
        im_croped, im_scale, real_shape = im_transform.crop_with_factor(
            img, inp_size, factor=cfg.MODEL.DOWNSAMPLE, is_ceil=True)
        #进行图片的处理
        #im_cropped size is => (*, 368, 3)
        if preprocess == 'rtpose':
            im_data = rtpose_preprocess(im_croped)
    
        elif preprocess == 'vgg':
            im_data = vgg_preprocess(im_croped)
    
        elif preprocess == 'inception':
            im_data = inception_preprocess(im_croped)
    
        elif preprocess == 'ssd':
            im_data = ssd_preprocess(im_croped)
    
        batch_images= np.expand_dims(im_data, 0)
    
        # several scales as a batch
        batch_var = torch.from_numpy(batch_images).cuda().float()
        #其中predicted_outputs是个tuple,后面是用来计算的loss我们用_来接住,不管它
        predicted_outputs, _ = model(batch_var)
        output1, output2 = predicted_outputs[-2], predicted_outputs[-1]
        
        heatmap = output2.cpu().data.numpy().transpose(0, 2, 3, 1)[0]
        paf = output1.cpu().data.numpy().transpose(0, 2, 3, 1)[0]
        #其中经过了8倍的下采样,(h // 8, w // 8, 38)
        #(h.//8, w//8, 19)
        return paf, heatmap, im_scale

     3. paf_to_pose_cpp函数,相对而言有点炸,里面涉及的是一个cpp代码,通过swig来进行的。大师总体的工作,,可以通过另一个代码https://blog.csdn.net/l297969586/article/details/80346254来进行对齐,后续有时间我会补上这个swig的代码,相对而言,就是通过一个采样,计算了关键点之间亲和度的方法以及关键点聚类的操作,但是里面有一些值不是很明白,还是说纯粹工程上作者们试出来的,whatever,通过这个函数,我们可以得到图片中有多少个人,以及这些人的关键点坐标,哪些关键点组成哪个人的哪个肢干等等数据,想要的数据都得到了。

    4. draw_humans 函数,进行画图,对于上面得到的结果,可以直接拿出每个人,每个人的关节点的坐标,直接进行画点,连线。。

  • 相关阅读:
    php获取随机字符串
    php短网址生成算法
    tp5.1发送邮件
    PHP简单 对象(object) 与 数组(array) 的转换
    PHP获取接下来一周的日期
    swoole 连接池
    PHP静态文件缓存
    php微信分享demo
    生成二维码并指定地址跳转
    tp5依赖注入(自动实例化):解决了像类中的方法传对象的问题
  • 原文地址:https://www.cnblogs.com/zonechen/p/11960114.html
Copyright © 2011-2022 走看看