zoukankan      html  css  js  c++  java
  • 从语义分割模型得到彩色的分割图过程

    将训练好的语义分割模型保存下来,重新加载之后

    通过这一个操作得到标签;

    output = self.model(image)
    这里的output即为标签内容,通过重新编码的函数来获得彩色图像.
     1 def decode_segmap(label_mask, dataset, plot=False):
     2     """Decode segmentation class labels into a color image
     3     解码标签,得到彩色的图像
     4     Args:
     5         label_mask (np.ndarray): an (M,N) array of integer values denoting
     6           the class label at each spatial location.
     7         plot (bool, optional): whether to show the resulting color image
     8           in a figure.
     9     Returns:
    10         (np.ndarray, optional): the resulting decoded color image.
    11     """
    12     if dataset == 'pascal' or dataset == 'coco':
    13         n_classes = 21
    14         label_colours = get_pascal_labels()
    15     elif dataset == 'cityscapes':
    16         n_classes = 19
    17         label_colours = get_cityscapes_labels()
    18     else:
    19         raise NotImplementedError
    20 
    21     r = label_mask.copy()
    22     g = label_mask.copy()
    23     b = label_mask.copy()
    24     for ll in range(0, n_classes):
    25         r[label_mask == ll] = label_colours[ll, 0]
    26         g[label_mask == ll] = label_colours[ll, 1]
    27         b[label_mask == ll] = label_colours[ll, 2]
    28     rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    29     rgb[:, :, 0] = r / 255.0
    30     rgb[:, :, 1] = g / 255.0
    31     rgb[:, :, 2] = b / 255.0
    32     if plot:
    33         plt.imshow(rgb)
    34         plt.show()
    35     else:
    36         return rgb

    绘图的主函数在下面:

     1 if __name__ == '__main__':
     2     from dataloaders.utils import decode_segmap
     3     from torch.utils.data import DataLoader
     4     import matplotlib.pyplot as plt
     5     import argparse
     6 
     7     parser = argparse.ArgumentParser()
     8     args = parser.parse_args()
     9     args.base_size = 256
    10     args.crop_size = 256
    11 
    12     voc_train = VOCSegmentation(args, split='train')
    13 
    14     dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0)
    15 
    16 
    17 
    18     for ii, sample in enumerate(dataloader):
    19         for jj in range(sample["image"].size()[0]):
    20             img = sample['image'].numpy()
    21             gt = sample['label'].numpy()
    22             tmp = np.array(gt[jj]).astype(np.uint8)
    23             segmap = decode_segmap(tmp, dataset='pascal')
    24             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
    25             img_tmp *= (0.229, 0.224, 0.225)
    26             img_tmp += (0.485, 0.456, 0.406)
    27             img_tmp *= 255.0
    28             img_tmp = img_tmp.astype(np.uint8)
    29             plt.figure()
    30             plt.title('display')
    31             plt.subplot(211)
    32             plt.imshow(img_tmp)
    33             plt.subplot(212)
    34             plt.imshow(segmap)
    35 
    36         if ii == 1:
    37             break
    38 
    39     plt.show(block=True)
  • 相关阅读:
    入门教程: JS认证和WebAPI
    ASP.NET Core 之 Identity 入门(二)
    在Visual Studio 2017中使用Asp.Net Core构建Angular4应用程序
    .Net Core+Angular Cli/Angular4开发环境搭建教程
    简单易用的.NET免费开源RabbitMQ操作组件EasyNetQ解析
    Razor
    一个简易的反射类库NMSReflector
    发布 Ionic iOS 企业级应用
    AngularJS中的Provider们:Service和Factory等的区别
    Linux企业运维人员必备150个命令汇总
  • 原文地址:https://www.cnblogs.com/ywheunji/p/10711150.html
Copyright © 2011-2022 走看看