zoukankan      html  css  js  c++  java
  • deeplabv3+ demo测试图像分割

    #直接复制本代码，存为 .py 文件；在第 204 行左右（`MODEL = DeepLabModel(...)`）更换模型文件路径，在第 223 行左右（`images_dir`）更换图片目录，直接执行即可得到简单的分割效果。
     11 #!--*-- coding:utf-8 --*--
     12 
     13 # Deeplab Demo
     14 
     15 import os
     16 import tarfile
     17 
     18 from matplotlib import gridspec
     19 import matplotlib.pyplot as plt
     20 import numpy as np
     21 from PIL import Image
     22 import tempfile
     23 from six.moves import urllib
     24 
     25 import tensorflow as tf
     26 
     27 
     28 class DeepLabModel(object):
     29     """
     30     加载 DeepLab 模型;
     31     推断 Inference.
     32     """
     33     INPUT_TENSOR_NAME = 'ImageTensor:0'
     34     OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
     35     INPUT_SIZE = 513
     36     FROZEN_GRAPH_NAME = 'frozen_inference_graph'
     37 
     38     def __init__(self, tarball_path):
     39         """
     40         加载预训练模型
     41         """
     42         self.graph = tf.Graph()
     43 
     44         graph_def = None
     45         # Extract frozen graph from tar archive.
     46         tar_file = tarfile.open(tarball_path)
     47         for tar_info in tar_file.getmembers():
     48             if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
     49                 file_handle = tar_file.extractfile(tar_info)
     50                 graph_def = tf.GraphDef.FromString(file_handle.read())
     51                 break
     52 
     53         tar_file.close()
     54 
     55         if graph_def is None:
     56             raise RuntimeError('Cannot find inference graph in tar archive.')
     57 
     58         with self.graph.as_default():
     59             tf.import_graph_def(graph_def, name='')
     60 
     61         self.sess = tf.Session(graph=self.graph)
     62 
     63 
     64     def run(self, image):
     65         """
     66 
     68         Args:
     69         image:  转换为PIL.Image 类,不能直接用图片,原始图片
     70 
     71         Returns:
     72         resized_image: RGB image resized from original input image.
     73         seg_map: Segmentation map of `resized_image`.
     74         """
     75         width, height = image.size
     76         resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
     77         target_size = (int(resize_ratio * width), int(resize_ratio * height))
     78         resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
     79         batch_seg_map = self.sess.run(self.OUTPUT_TENSOR_NAME,
     80                                       feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
     81         seg_map = batch_seg_map[0]
     82         return resized_image, seg_map
     83 
     84 
     85 def create_pascal_label_colormap():
     86     """
     87     Creates a label colormap used in PASCAL VOC segmentation benchmark.
     88 
     89     Returns:
     90         A Colormap for visualizing segmentation results.
     91     """
     92     colormap = np.zeros((256, 3), dtype=int)
     93     ind = np.arange(256, dtype=int)
     94 
     95     for shift in reversed(range(8)):
     96         for channel in range(3):
     97             colormap[:, channel] |= ((ind >> channel) & 1) << shift
     98         ind >>= 3
     99 
    100     return colormap
    101 
    102 
    103 def label_to_color_image(label):
    104     """
    105     Adds color defined by the dataset colormap to the label.
    106 
    107     Args:
    108         label: A 2D array with integer type, storing the segmentation label.
    109 
    110     Returns:
    111         result: A 2D array with floating type. The element of the array
    112         is the color indexed by the corresponding element in the input label
    113         to the PASCAL color map.
    114 
    115     Raises:
    116         ValueError: If label is not of rank 2 or its value is larger than color
    117         map maximum entry.
    118     """
    119     if label.ndim != 2:
    120         raise ValueError('Expect 2-D input label')
    121 
    122     colormap = create_pascal_label_colormap()
    123 
    124     if np.max(label) >= len(colormap):
    125         raise ValueError('label value too large.')
    126 
    127     return colormap[label]
    128 
    129 
    130 def vis_segmentation(image, seg_map, imagefile):
    131     """可视化三种图像."""
    132     plt.figure(figsize=(15, 5))
    133     grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])
    134 
    135     plt.subplot(grid_spec[0])
    136     plt.imshow(image)
    137     plt.axis('off')
    138     plt.title('input image')
    139 
    140     plt.subplot(grid_spec[1])
    141     seg_image = label_to_color_image(seg_map).astype(np.uint8)
    142     # seg_image = label_to_color_image(seg_map)
    143     # seg_image.save('/str(ss)+imagefile')
    144     plt.imshow(seg_image)
    145     plt.savefig('./'+imagefile+'.png')
    146 
    147     plt.axis('off')
    148     plt.title('segmentation map')
    149 
    150     plt.subplot(grid_spec[2])
    151     plt.imshow(image)
    152     plt.imshow(seg_image, alpha=0.7)
    153     plt.axis('off')
    154     plt.title('segmentation overlay')
    155 
    156     unique_labels = np.unique(seg_map)
    157     ax = plt.subplot(grid_spec[3])
    158     plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    159     ax.yaxis.tick_right()
    160     plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    161     plt.xticks([], [])
    162     ax.tick_params(width=0.0)
    163     plt.grid('off')
    164     plt.show()
    165 
    166 
    167 ##
    168 LABEL_NAMES = np.asarray(['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    169                           'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    170                           'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv' ])
    171 
    172 FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    173 FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
    174 
    175 
    176 ## Tensorflow 提供的模型下载
    177 MODEL_NAME = 'xception_coco_voctrainval'
    178 # ['mobilenetv2_coco_voctrainaug', 'mobilenetv2_coco_voctrainval', 'xception_coco_voctrainaug', 'xception_coco_voctrainval']
    179 
    180 _DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'
    181 _MODEL_URLS = {'mobilenetv2_coco_voctrainaug': 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
    182                'mobilenetv2_coco_voctrainval': 'deeplabv3_mnv2_pascal_trainval_2018_01_29.tar.gz',
    183                'xception_coco_voctrainaug': 'deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
    184                'xception_coco_voctrainval': 'deeplabv3_pascal_trainval_2018_01_04.tar.gz', }
    185 
    186 
    187 _TARBALL_NAME = 'deeplab_model.tar.gz'
    188 
    189 # model_dir = tempfile.mkdtemp()
    190 model_dir = './'
    191 # tf.gfile.MakeDirs(model_dir)
    192 
    193 #
    194 download_path = os.path.join(model_dir, _TARBALL_NAME)
    195 print('downloading model, this might take a while...')
    196 # urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + _MODEL_URLS[MODEL_NAME], download_path)
    197 print('download completed! loading DeepLab model...')
    198 
    199 
    200 
    201 # model_dir = '/‘
    202 
    203 # download_path = os.path.join(model_dir, _MODEL_URLS[MODEL_NAME])
    204 MODEL = DeepLabModel('./deeplab_model.tar.gz')
    205 # MODEL = './deeplab_model.tar.gz'
    206 print('model loaded successfully!')
    207 
    208 
    209 ##
    210 def run_visualization(imagefile):
    211     """
    212     DeepLab 语义分割,并可视化结果.
    213     """
    214     # orignal_im = Image.open(imagefile)
    215     # print(type(orignal_im))
    216     # orignal_im.show()
    217     print('running deeplab on image %s...' % imagefile)
    218     resized_im, seg_map = MODEL.run(Image.open(imagefile))
    219 
    220 
    221     vis_segmentation(resized_im, seg_map,imagefile)
    222 
    223 images_dir = './pictures'
    224 images = sorted(os.listdir(images_dir))
    225 print(images)
    226 # img='205729y9fodss9ao6ol5921-150x150.jpg'
    227 # img.show()
    228 for imgfile in images:
    229 # img.show()
    230     run_visualization(os.path.join(images_dir, imgfile))
    231 
    232 print('Done.')

    本文使用的模型文件是 deeplab_model.tar.gz；也可以修改代码，改用其他在标准数据集上预训练过的模型（相关代码在第 177–184 行附近的 `MODEL_NAME` 与 `_MODEL_URLS`）。

    1.修改模型保存路径

    2.修改图片路径

    3.运行即可

    参考自:https://www.aiuai.cn/aifarm252.html

  • 相关阅读:
    bzoj1618 / P2918 [USACO08NOV]买干草Buying Hay(完全背包)
    bzoj1617 / P2904 [USACO08MAR]跨河River Crossing
    bzoj1615 / P2903 [USACO08MAR]麻烦的干草打包机The Loathesome Hay Baler
    bzoj1613 / P1353 [USACO08JAN]跑步Running
    bzoj1612 / P2419 [USACO08JAN]牛大赛Cow Contest(Floyd)
    bzoj1611 / P2895 [USACO08FEB]流星雨Meteor Shower
    bzoj1610 / P2665 [USACO08FEB]连线游戏Game of Lines
    bzoj1609 / P2896 [USACO08FEB]一起吃饭Eating Together(最长不降子序列)
    bzoj1606 / P2925 [USACO08DEC]干草出售Hay For Sale(01背包)
    [bzoj1041][HAOI2008]圆上的整点
  • 原文地址:https://www.cnblogs.com/ywheunji/p/10541818.html
Copyright © 2011-2022 走看看