• (Original) Head pose estimation: HYBRID COARSE-FINE CLASSIFICATION FOR HEAD POSE ESTIMATION

Please cite the source when reposting:

    https://www.cnblogs.com/darkknightzh/p/12150128.html

Paper:

    HYBRID COARSE-FINE CLASSIFICATION FOR HEAD POSE ESTIMATION

Paper URL:

    https://arxiv.org/abs/1901.06778

Official PyTorch code:

    https://github.com/haofanwang/accurate-head-pose

     

This paper proposes a coarse-fine classification scheme.

1. Network structure

The network structure of the paper is shown in the figure below. The input image first passes through the backbone network to produce a feature vector, which is then fed into several separate fc layers. These fc layers map the feature to angle bins of different widths within -99° to 102° (bin widths of 1, 3, 11, 33 and 99 degrees respectively). A softmax then normalizes each head's logits, and two branches follow: one computes the expectation over bins and the MSE loss between this expectation and the ground truth; the other computes a cross-entropy loss. Summing all terms gives the final loss.
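To make the coarse-to-fine binning concrete, here is a minimal sketch of my own (not from the repo; I assume np.digitize-style bin edges over -99°..102°, in the manner Hopenet bins angles) mapping one continuous angle to a class index at each of the five granularities:

import numpy as np

def multi_granularity_labels(angle_deg):
    # bin widths 1/3/11/33/99 degrees correspond to the 198/66/18/6/2 classes mentioned above
    labels = {}
    for width in (1, 3, 11, 33, 99):
        edges = np.arange(-99, 102, width)            # assumed bin edges
        labels[width] = int(np.digitize(angle_deg, edges)) - 1
    return labels

print(multi_granularity_labels(10.0))  # -> {1: 109, 3: 36, 11: 9, 33: 3, 99: 1}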

1)     The MSE loss is close to that of Deep Head Pose (the difference is that here the expectation is computed from the 198-class classification result, whereas Deep Head Pose uses 66 classes).

2)     The other bin granularities (all except the 198-class one) are used only to compute cross-entropy losses (as shown in the figure below).

3)     The cross-entropy losses of the different bin granularities carry different weights.

4)     The MSE loss in this paper carries a comparatively large weight (it is set to 2).

5)     During training, probabilities are computed with a plain softmax. At test time, a softmax with temperature is used (since the code sets T=1, this is actually equivalent to a plain softmax).

6)     From https://arxiv.org/abs/1503.02531: given input logits ${{z}_{i}}$, the softmax-with-temperature output ${{q}_{i}}$ is computed as:

${{q}_{i}}=\frac{\exp ({{z}_{i}}/T)}{\sum\nolimits_{j}{\exp ({{z}_{j}}/T)}}$

where T is the temperature, usually set to 1 (which recovers the standard softmax). The larger T is, the smaller the differences among the output probabilities; the smaller T is (the closer to 0), the larger the differences among the output probabilities.
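As a quick check of this claim, here is a toy snippet (my own, independent of the paper's code) applying softmax with different temperatures to the same logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.1])
for T in (0.5, 1.0, 5.0):
    print(T, F.softmax(logits / T, dim=0))
# T=0.5 sharpens the distribution (~[0.86, 0.12, 0.02]);
# T=5.0 flattens it toward uniform (~[0.40, 0.33, 0.27])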

Hence, I feel the figure above would be easier to understand redrawn as below:

The loss function of this paper is as follows:

$Loss=\alpha \cdot MSE(y,{{y}^{*}})+\sum\limits_{i=1}^{num}{{{\beta }_{i}}\cdot H({{y}_{i}},y_{i}^{*})}$

where H denotes the cross-entropy loss and ${{\beta }_{i}}$ is the weight of the cross-entropy loss at the i-th bin granularity (see the code for the specific values).
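Matching this formula with the weights used in the training code below (α = 2 for the MSE term, β = (7, 5, 3, 1, 1) for the five cross-entropy terms), the per-angle total loss is, schematically:

# schematic per-angle loss; mse is the expectation-vs-ground-truth MSE,
# ce is the list of cross-entropy losses at the 198/66/18/6/2-class heads
alpha, betas = 2.0, (7.0, 5.0, 3.0, 1.0, 1.0)

def total_loss(mse, ce):
    return alpha * mse + sum(b * c for b, c in zip(betas, ce))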

2. Code

2.1 Network structure

import math
import torch.nn as nn

class Multinet(nn.Module):
    # Hopenet-style network with multiple output layers for yaw, pitch and roll
    # Predicts Euler angles by binning and regression with the expected value
    def __init__(self, block, layers, num_bins):
        self.inplanes = 64
        super(Multinet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7)    # up to here: the ResNet backbone
        self.fc_yaw = nn.Linear(512 * block.expansion, num_bins)     # as in Hopenet, except num_bins=198
        self.fc_pitch = nn.Linear(512 * block.expansion, num_bins)   # as in Hopenet, except num_bins=198
        self.fc_roll = nn.Linear(512 * block.expansion, num_bins)    # as in Hopenet, except num_bins=198

        self.fc_yaw_1 = nn.Linear(512 * block.expansion, 66)   # 66 classes, matching Deep Head Pose
        self.fc_yaw_2 = nn.Linear(512 * block.expansion, 18)   # the remaining fc layers are new
        self.fc_yaw_3 = nn.Linear(512 * block.expansion, 6)
        self.fc_yaw_4 = nn.Linear(512 * block.expansion, 2)

        self.fc_pitch_1 = nn.Linear(512 * block.expansion, 66)
        self.fc_pitch_2 = nn.Linear(512 * block.expansion, 18)
        self.fc_pitch_3 = nn.Linear(512 * block.expansion, 6)
        self.fc_pitch_4 = nn.Linear(512 * block.expansion, 2)

        self.fc_roll_1 = nn.Linear(512 * block.expansion, 66)
        self.fc_roll_2 = nn.Linear(512 * block.expansion, 18)
        self.fc_roll_3 = nn.Linear(512 * block.expansion, 6)
        self.fc_roll_4 = nn.Linear(512 * block.expansion, 2)

        # Vestigial layer from previous experiments (unused)
        self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # backbone features
        pre_yaw = self.fc_yaw(x)   # below: logits for yaw, pitch and roll at each granularity
        pre_pitch = self.fc_pitch(x)
        pre_roll = self.fc_roll(x)

        pre_yaw_1 = self.fc_yaw_1(x)
        pre_pitch_1 = self.fc_pitch_1(x)
        pre_roll_1 = self.fc_roll_1(x)

        pre_yaw_2 = self.fc_yaw_2(x)
        pre_pitch_2 = self.fc_pitch_2(x)
        pre_roll_2 = self.fc_roll_2(x)

        pre_yaw_3 = self.fc_yaw_3(x)
        pre_pitch_3 = self.fc_pitch_3(x)
        pre_roll_3 = self.fc_roll_3(x)

        pre_yaw_4 = self.fc_yaw_4(x)
        pre_pitch_4 = self.fc_pitch_4(x)
        pre_roll_4 = self.fc_roll_4(x)

        return (pre_yaw, pre_yaw_1, pre_yaw_2, pre_yaw_3, pre_yaw_4,
                pre_pitch, pre_pitch_1, pre_pitch_2, pre_pitch_3, pre_pitch_4,
                pre_roll, pre_roll_1, pre_roll_2, pre_roll_3, pre_roll_4)
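As a sanity check, a minimal usage sketch of this class (my own, assuming the ResNet-50 Bottleneck configuration used in the scripts below): the five heads per angle emit 198/66/18/6/2 logits, i.e. bin widths of 1°, 3°, 11°, 33° and 99° over the same 198° range:

import torch
import torchvision

model = Multinet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 198)
outs = model(torch.randn(1, 3, 224, 224))   # 15 tensors: 5 yaw, 5 pitch, 5 roll
print([o.shape[1] for o in outs[:5]])       # [198, 66, 18, 6, 2] (the yaw heads)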

2.2 Training code

import argparse
import os
import sys

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.model_zoo as model_zoo
import torchvision
from torch.autograd import Variable
from torchvision import transforms

import datasets
import hopenet

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int)
    parser.add_argument('--num_epochs', dest='num_epochs', help='Maximum number of training epochs.', default=25, type=int)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.', default=32, type=int)
    parser.add_argument('--lr', dest='lr', help='Base learning rate.', default=0.000001, type=float)
    parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='AFLW_multi', type=str)
    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.', default='', type=str)
    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.', default='/tools/AFLW_train.txt', type=str)
    parser.add_argument('--output_string', dest='output_string', help='String appended to output snapshots.', default='', type=str)
    parser.add_argument('--alpha', dest='alpha', help='Regression loss coefficient.', default=2, type=float)
    parser.add_argument('--snapshot', dest='snapshot', help='Path of model snapshot.', default='', type=str)

    args = parser.parse_args()
    return args

def get_ignored_params(model):
    # Generator function that yields ignored params.
    b = [model.conv1, model.bn1, model.fc_finetune]
    for i in range(len(b)):
        for module_name, module in b[i].named_modules():
            if 'bn' in module_name:
                module.eval()
            for name, param in module.named_parameters():
                yield param

def get_non_ignored_params(model):
    # Generator function that yields params that will be optimized.
    b = [model.layer1, model.layer2, model.layer3, model.layer4]
    for i in range(len(b)):
        for module_name, module in b[i].named_modules():
            if 'bn' in module_name:
                module.eval()
            for name, param in module.named_parameters():
                yield param

def get_fc_params(model):
    # Generator function that yields fc layer params.
    b = [model.fc_yaw, model.fc_pitch, model.fc_roll,
         model.fc_yaw_1, model.fc_pitch_1, model.fc_roll_1,
         model.fc_yaw_2, model.fc_pitch_2, model.fc_roll_2,
         model.fc_yaw_3, model.fc_pitch_3, model.fc_roll_3]
    for i in range(len(b)):
        for module_name, module in b[i].named_modules():
            for name, param in module.named_parameters():
                yield param

def load_filtered_state_dict(model, snapshot):
    # By user apaszke from discuss.pytorch.org
    model_dict = model.state_dict()
    snapshot = {k: v for k, v in snapshot.items() if k in model_dict}
    model_dict.update(snapshot)
    model.load_state_dict(model_dict)

if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    gpu = args.gpu_id

    if not os.path.exists('output/snapshots'):
        os.makedirs('output/snapshots')

    # ResNet50 structure
    model = hopenet.Multinet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 198)   # build the model

    if args.snapshot == '':
        load_filtered_state_dict(model, model_zoo.load_url('https://download.pytorch.org/models/resnet50-19c8e357.pth'))
    else:
        saved_state_dict = torch.load(args.snapshot)
        model.load_state_dict(saved_state_dict)

    print('Loading data.')

    transformations = transforms.Compose([transforms.Resize(240),
    transforms.RandomCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    if args.dataset == 'Pose_300W_LP':
        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Pose_300W_LP_multi':
        pose_dataset = datasets.Pose_300W_LP_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Pose_300W_LP_random_ds':
        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Synhead':
        pose_dataset = datasets.Synhead(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW2000':
        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'BIWI':
        pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'BIWI_multi':
        pose_dataset = datasets.BIWI_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW':
        pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW_multi':        # dataset class used here
        pose_dataset = datasets.AFLW_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW_aug':
        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFW':
        pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
    else:
        print('Error: not a valid dataset name')
        sys.exit()

    train_loader = torch.utils.data.DataLoader(dataset=pose_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    model.cuda(gpu)
    criterion = nn.CrossEntropyLoss().cuda(gpu)
    reg_criterion = nn.MSELoss().cuda(gpu)
    # Regression loss coefficient
    alpha = args.alpha

    softmax = nn.Softmax(dim=1).cuda(gpu)
    idx_tensor = [idx for idx in range(198)]
    idx_tensor = Variable(torch.FloatTensor(idx_tensor)).cuda(gpu)

    optimizer = torch.optim.Adam([{'params': get_ignored_params(model), 'lr': 0},
                                  {'params': get_non_ignored_params(model), 'lr': args.lr},
                                  {'params': get_fc_params(model), 'lr': args.lr * 5}],
                                   lr = args.lr)

    print('Ready to train network.')
    for epoch in range(num_epochs):
        for i, (images, labels, labels_0, labels_1, labels_2, labels_3, cont_labels, name) in enumerate(train_loader):
            images = Variable(images).cuda(gpu)

            # Binned labels
            label_yaw = Variable(labels[:,0]).cuda(gpu)
            label_pitch = Variable(labels[:,1]).cuda(gpu)
            label_roll = Variable(labels[:,2]).cuda(gpu)

            label_yaw_1 = Variable(labels_0[:,0]).cuda(gpu)
            label_pitch_1 = Variable(labels_0[:,1]).cuda(gpu)
            label_roll_1 = Variable(labels_0[:,2]).cuda(gpu)

            label_yaw_2 = Variable(labels_1[:,0]).cuda(gpu)
            label_pitch_2 = Variable(labels_1[:,1]).cuda(gpu)
            label_roll_2 = Variable(labels_1[:,2]).cuda(gpu)

            label_yaw_3 = Variable(labels_2[:,0]).cuda(gpu)
            label_pitch_3 = Variable(labels_2[:,1]).cuda(gpu)
            label_roll_3 = Variable(labels_2[:,2]).cuda(gpu)

            label_yaw_4 = Variable(labels_3[:,0]).cuda(gpu)
            label_pitch_4 = Variable(labels_3[:,1]).cuda(gpu)
            label_roll_4 = Variable(labels_3[:,2]).cuda(gpu)

            # Continuous labels
            label_yaw_cont = Variable(cont_labels[:,0]).cuda(gpu)
            label_pitch_cont = Variable(cont_labels[:,1]).cuda(gpu)
            label_roll_cont = Variable(cont_labels[:,2]).cuda(gpu)

            # Forward pass: logits at every granularity
            yaw,yaw_1,yaw_2,yaw_3,yaw_4, pitch,pitch_1,pitch_2,pitch_3,pitch_4, roll,roll_1,roll_2,roll_3,roll_4 = model(images)

            # Cross entropy losses at each granularity
            loss_yaw,loss_yaw_1,loss_yaw_2,loss_yaw_3,loss_yaw_4 = criterion(yaw, label_yaw),criterion(yaw_1, label_yaw_1),criterion(yaw_2, label_yaw_2),criterion(yaw_3, label_yaw_3),criterion(yaw_4, label_yaw_4)
            loss_pitch,loss_pitch_1,loss_pitch_2,loss_pitch_3,loss_pitch_4 = criterion(pitch, label_pitch),criterion(pitch_1, label_pitch_1),criterion(pitch_2, label_pitch_2),criterion(pitch_3, label_pitch_3),criterion(pitch_4, label_pitch_4)
            loss_roll,loss_roll_1,loss_roll_2,loss_roll_3,loss_roll_4 = criterion(roll, label_roll),criterion(roll_1, label_roll_1),criterion(roll_2, label_roll_2),criterion(roll_3, label_roll_3),criterion(roll_4, label_roll_4)

            # MSE loss: normalize the 198-class logits
            yaw_predicted = softmax(yaw)
            pitch_predicted = softmax(pitch)
            roll_predicted = softmax(roll)

            yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1) - 99  # expectation, same as Deep Head Pose
            pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1) - 99
            roll_predicted = torch.sum(roll_predicted * idx_tensor, 1) - 99

            loss_reg_yaw = reg_criterion(yaw_predicted, label_yaw_cont)  # same as Deep Head Pose
            loss_reg_pitch = reg_criterion(pitch_predicted, label_pitch_cont)
            loss_reg_roll = reg_criterion(roll_predicted, label_roll_cont)

            # Total loss: weighted sum over all bin granularities
            total_loss_yaw = alpha * loss_reg_yaw + 7*loss_yaw + 5*loss_yaw_1 + 3*loss_yaw_2 + 1*loss_yaw_3 + 1*loss_yaw_4
            total_loss_pitch = alpha * loss_reg_pitch + 7*loss_pitch + 5*loss_pitch_1 + 3*loss_pitch_2 + 1*loss_pitch_3 + 1*loss_pitch_4
            total_loss_roll = alpha * loss_reg_roll + 7*loss_roll + 5*loss_roll_1 + 3*loss_roll_2 + 1*loss_roll_3 + 1*loss_roll_4  # the official code sums loss_pitch_4 here, which looks like a typo

            loss_seq = [total_loss_yaw, total_loss_pitch, total_loss_roll]
            grad_seq = [torch.tensor(1.0).cuda(gpu) for _ in range(len(loss_seq))]
            optimizer.zero_grad()
            torch.autograd.backward(loss_seq, grad_seq)
            optimizer.step()

            if (i+1) % 100 == 0:
                print ('Epoch [%d/%d], Iter [%d/%d] Losses: Yaw %.4f, Pitch %.4f, Roll %.4f'
                       %(epoch+1, num_epochs, i+1, len(pose_dataset)//batch_size, total_loss_yaw.item(), total_loss_pitch.item(), total_loss_roll.item()))
        # Save models at numbered epochs.
        if epoch % 1 == 0 and epoch < num_epochs:
            print('Taking snapshot...')
            torch.save(model.state_dict(),
            'output/snapshots/' + args.output_string + '_epoch_'+ str(epoch+1) + '.pkl')
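One note on the backward pass above: torch.autograd.backward(loss_seq, grad_seq) with unit gradients accumulates the gradients of all three losses, so it is equivalent to backpropagating their sum, e.g.:

# equivalent to the loss_seq/grad_seq backward above
total = total_loss_yaw + total_loss_pitch + total_loss_roll
optimizer.zero_grad()
total.backward()
optimizer.step()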

2.3 Test code

import argparse
import os
import sys

import cv2
import torch
import torch.backends.cudnn as cudnn
import torchvision
from torch.autograd import Variable
from torchvision import transforms

import datasets
import hopenet
import utils

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Head pose estimation using the Hopenet network.')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]', default=0, type=int)
    parser.add_argument('--data_dir', dest='data_dir', help='Directory path for data.', default='/AFLW2000/', type=str)
    parser.add_argument('--filename_list', dest='filename_list', help='Path to text file containing relative paths for every example.',
          default='/tools/AFLW2000_filename_filtered.txt', type=str)
    parser.add_argument('--snapshot', dest='snapshot', help='Name of model snapshot.',
          default='/output/snapshots/AFLW2000/_epoch_9.pkl', type=str)
    parser.add_argument('--batch_size', dest='batch_size', help='Batch size.', default=1, type=int)
    parser.add_argument('--save_viz', dest='save_viz', help='Save images with pose cube.', default=False, type=bool)
    parser.add_argument('--dataset', dest='dataset', help='Dataset type.', default='AFLW2000', type=str)

    args = parser.parse_args()

    return args

if __name__ == '__main__':
    args = parse_args()

    cudnn.enabled = True
    gpu = args.gpu_id
    snapshot_path = args.snapshot

    # ResNet50 structure
    model = hopenet.Multinet(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], 198)

    print('Loading snapshot.')
    # Load snapshot
    saved_state_dict = torch.load(snapshot_path)
    model.load_state_dict(saved_state_dict)

    print('Loading data.')

    transformations = transforms.Compose([transforms.Resize(224),
    transforms.CenterCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    if args.dataset == 'Pose_300W_LP':
        pose_dataset = datasets.Pose_300W_LP(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Pose_300W_LP_multi':
        pose_dataset = datasets.Pose_300W_LP_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Pose_300W_LP_random_ds':
        pose_dataset = datasets.Pose_300W_LP_random_ds(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'Synhead':
        pose_dataset = datasets.Synhead(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW2000':
        pose_dataset = datasets.AFLW2000(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'BIWI':
        pose_dataset = datasets.BIWI(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'BIWI_multi':
        pose_dataset = datasets.BIWI_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW':
        pose_dataset = datasets.AFLW(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW_multi':
        pose_dataset = datasets.AFLW_multi(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFLW_aug':
        pose_dataset = datasets.AFLW_aug(args.data_dir, args.filename_list, transformations)
    elif args.dataset == 'AFW':
        pose_dataset = datasets.AFW(args.data_dir, args.filename_list, transformations)
    else:
        print('Error: not a valid dataset name')
        sys.exit()
    test_loader = torch.utils.data.DataLoader(dataset=pose_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=2)

    model.cuda(gpu)

    print('Ready to test network.')

    # Test the Model
    model.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
    total = 0

    idx_tensor = [idx for idx in range(198)]
    idx_tensor = torch.FloatTensor(idx_tensor).cuda(gpu)

    yaw_error = .0
    pitch_error = .0
    roll_error = .0

    l1loss = torch.nn.L1Loss(size_average=False)
    for i, (images, labels, cont_labels, name) in enumerate(test_loader):
    #for i, (images, labels, labels_0, labels_1, labels_2, labels_3, cont_labels, name) in enumerate(test_loader):
        images = Variable(images).cuda(gpu)
        total += cont_labels.size(0)

        label_yaw = cont_labels[:,0].float()
        label_pitch = cont_labels[:,1].float()
        label_roll = cont_labels[:,2].float()

        yaw,yaw_1,yaw_2,yaw_3,yaw_4, pitch,pitch_1,pitch_2,pitch_3,pitch_4, roll,roll_1,roll_2,roll_3,roll_4 = model(images)  # logits at every granularity

        # Binned predictions
        _, yaw_bpred = torch.max(yaw.data, 1)
        _, pitch_bpred = torch.max(pitch.data, 1)
        _, roll_bpred = torch.max(roll.data, 1)

        # Continuous predictions
        yaw_predicted = utils.softmax_temperature(yaw.data, 1)  # softmax with temperature (T=1)
        pitch_predicted = utils.softmax_temperature(pitch.data, 1)
        roll_predicted = utils.softmax_temperature(roll.data, 1)

        yaw_predicted = torch.sum(yaw_predicted * idx_tensor, 1).cpu() - 99     # expectation over bins
        pitch_predicted = torch.sum(pitch_predicted * idx_tensor, 1).cpu() - 99
        roll_predicted = torch.sum(roll_predicted * idx_tensor, 1).cpu() - 99

        # Mean absolute error
        yaw_error += torch.sum(torch.abs(yaw_predicted - label_yaw))
        pitch_error += torch.sum(torch.abs(pitch_predicted - label_pitch))
        roll_error += torch.sum(torch.abs(roll_predicted - label_roll))

        # Save first image in batch with pose cube or axis.
        if args.save_viz:
            name = name[0]
            if args.dataset == 'BIWI':
                cv2_img = cv2.imread(os.path.join(args.data_dir, name + '_rgb.png'))
            else:
                cv2_img = cv2.imread(os.path.join(args.data_dir, name + '.jpg'))
            if args.batch_size == 1:
                error_string = 'y %.2f, p %.2f, r %.2f' % (torch.sum(torch.abs(yaw_predicted - label_yaw)), torch.sum(torch.abs(pitch_predicted - label_pitch)), torch.sum(torch.abs(roll_predicted - label_roll)))
                cv2.putText(cv2_img, error_string, (30, cv2_img.shape[0]- 30), fontFace=1, fontScale=1, color=(0,0,255), thickness=2)
            # utils.plot_pose_cube(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], size=100)
            utils.draw_axis(cv2_img, yaw_predicted[0], pitch_predicted[0], roll_predicted[0], tdx = 200, tdy= 200, size=100)
            cv2.imwrite(os.path.join('output/images', name + '.jpg'), cv2_img)

    print('Test error in degrees of the model on the ' + str(total) +
    ' test images. Yaw: %.4f, Pitch: %.4f, Roll: %.4f, MAE: %.4f' % (yaw_error / total,
    pitch_error / total, roll_error / total, (yaw_error+pitch_error+roll_error)/(3.0*total)))
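To recap the decoding step: at test time only the finest 198-bin head is used; its softmax expectation over bin indices, minus 99, gives the angle in degrees. A minimal single-image sketch (my own; `img` stands for a normalized 1x3x224x224 tensor prepared with the transforms above):

idx = torch.arange(198, dtype=torch.float32).cuda(gpu)
with torch.no_grad():
    yaw, *rest = model(img)                  # the finest 198-bin yaw head comes first
yaw_deg = (utils.softmax_temperature(yaw, 1) * idx).sum(1) - 99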

2.4 softmax_temperature code

import torch

def softmax_temperature(tensor, temperature):
    result = torch.exp(tensor / temperature)
    result = torch.div(result, torch.sum(result, 1).unsqueeze(1).expand_as(result))  # softmax with temperature
    return result
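One caveat on this implementation: exponentiating raw logits can overflow for large values. A numerically stable variant (my own sketch, not part of the repo) subtracts the per-row maximum first, which leaves the output unchanged:

def softmax_temperature_stable(tensor, temperature):
    z = tensor / temperature
    z = z - z.max(dim=1, keepdim=True)[0]   # shift logits for numerical stability
    e = torch.exp(z)
    return e / e.sum(dim=1, keepdim=True)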