  • DeepLabV3+: Training on Your Own Dataset (Part 3)

    Model Training and Testing

    1. Starting from the DeepLabv3+ code, two files mainly need to be modified:

       data_generator.py

       train_utils.py

       (1) Add a dataset description

       In datasets/data_generator.py, add a description of your own dataset:
    _CAMVID_INFORMATION = DatasetDescriptor(
        splits_to_sizes={
        'train': 1035,
        'val': 31,},
        num_classes=3,
        ignore_label=255, )
    The dataset has 3 classes in total, including background. This dataset does not use ignore_label (255), so it is not counted in num_classes.
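      Since num_classes=3 counts background and 255 is reserved as ignore_label, every label mask should contain only the values 0, 1 and 2 (plus, possibly, 255). The following is a minimal NumPy sketch for checking this. The helper name check_label_values is hypothetical and not part of DeepLab, and label_png_path in the comment stands for one of your own annotation files:

```python
import numpy as np


def check_label_values(mask, num_classes=3, ignore_label=255):
    """Return sorted label values in `mask` that are neither a class id nor ignore_label."""
    values = np.unique(np.asarray(mask))
    valid = ((values >= 0) & (values < num_classes)) | (values == ignore_label)
    return sorted(int(v) for v in values[~valid])


# A real mask would come from np.array(Image.open(label_png_path)).
good = np.array([[0, 1], [2, 255]], dtype=np.uint8)
bad = np.array([[0, 3], [128, 1]], dtype=np.uint8)
print(check_label_values(good))  # []
print(check_label_values(bad))   # [3, 128]
```

      Any value this reports, such as 128 from an anti-aliased annotation edge, would be treated as an out-of-range class during training.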

      (2) Register the dataset by adding it to _DATASETS_INFORMATION in the same file:

    _DATASETS_INFORMATION = {
        'cityscapes': _CITYSCAPES_INFORMATION,
        'pascal_voc_seg': _PASCAL_VOC_SEG_INFORMATION,
        'ade20k': _ADE20K_INFORMATION,
        'camvid':_CAMVID_INFORMATION,
        # 'mydata':_MYDATA_INFORMATION,
        }
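      The --dataset flag passed to train.py is looked up as a key in _DATASETS_INFORMATION, so the key ('camvid' here) and the flag value must match exactly. Below is a simplified, self-contained mimic of that lookup. The DatasetDescriptor field names are assumed to match DeepLab's namedtuple, and get_dataset_info is a hypothetical helper; DeepLab performs a similar check inside data_generator.py rather than through this function:

```python
import collections

# Field names assumed to match DeepLab's DatasetDescriptor namedtuple.
DatasetDescriptor = collections.namedtuple(
    'DatasetDescriptor', ['splits_to_sizes', 'num_classes', 'ignore_label'])

_CAMVID_INFORMATION = DatasetDescriptor(
    splits_to_sizes={'train': 1035, 'val': 31},
    num_classes=3,
    ignore_label=255,
)

_DATASETS_INFORMATION = {'camvid': _CAMVID_INFORMATION}


def get_dataset_info(dataset_name):
    """Hypothetical helper: resolve a --dataset value to its descriptor."""
    if dataset_name not in _DATASETS_INFORMATION:
        raise ValueError('Unsupported dataset: %s' % dataset_name)
    return _DATASETS_INFORMATION[dataset_name]


info = get_dataset_info('camvid')
print(info.num_classes, info.splits_to_sizes['train'])  # 3 1035
```

      Passing a name that was never registered (for example the commented-out 'mydata') fails immediately instead of silently training on the wrong data.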

      (3) Modify train_utils.py

      In utils/train_utils.py, change the exclude_list setting at line 210 so that the logits layer is not loaded when the pretrained weights are restored:

      

    exclude_list = ['global_step','logits']
    if not initialize_last_layer:
        exclude_list.extend(last_layers)

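      To fine-tune DeepLab on another dataset, change the input flags in deeplab/train.py.

      Some options:
        To use all pretrained weights, set initialize_last_layer=True.
        To use only the network backbone, set initialize_last_layer=False and
        last_layers_contain_logits_only=False.
        To use all pretrained weights except the logits, set initialize_last_layer=False and last_layers_contain_logits_only=True. This is the usual choice for your own dataset, because its number of classes differs (and the logits layer is already excluded above).
      The settings used here are:
      initialize_last_layer=False  # line 157
      last_layers_contain_logits_only=True  # line 160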
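      How the two flags interact with the modified exclude_list can be modeled in plain Python. This sketch assumes, based on DeepLab's model.get_extra_layer_scopes, that the "last layers" are just logits when last_layers_contain_logits_only=True and otherwise also include the image pooling, ASPP, concat projection, decoder and meta-architecture scopes. It approximates the checkpoint exclusion by simple name-prefix matching, and the variable names are illustrative rather than read from a real checkpoint:

```python
def last_layer_scopes(last_layers_contain_logits_only):
    # Assumed to mirror deeplab/model.py get_extra_layer_scopes().
    if last_layers_contain_logits_only:
        return ['logits']
    return ['logits', 'image_pooling', 'aspp', 'concat_projection',
            'decoder', 'meta_architecture']


def variables_to_restore(var_names, initialize_last_layer,
                         last_layers_contain_logits_only):
    """Return the variables that would be restored (prefix-match approximation)."""
    exclude_list = ['global_step', 'logits']  # the modified setting from train_utils.py
    if not initialize_last_layer:
        exclude_list.extend(last_layer_scopes(last_layers_contain_logits_only))
    return [name for name in var_names
            if not any(name.startswith(scope) for scope in exclude_list)]


# Illustrative variable names only.
names = ['global_step', 'xception_65/entry_flow/conv1_1/weights',
         'aspp0/weights', 'decoder/feature_projection0/weights',
         'logits/semantic/weights']
print(variables_to_restore(names, False, True))
# ['xception_65/entry_flow/conv1_1/weights', 'aspp0/weights', 'decoder/feature_projection0/weights']
print(variables_to_restore(names, False, False))
# ['xception_65/entry_flow/conv1_1/weights']
```

      With the settings used here (False, True), everything except the logits comes from the Cityscapes checkpoint; with (False, False), only the backbone does.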

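    2. Network training

      (1) Download a pretrained model

      Download from: https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md  

      Download it into the deeplab directory, then extract it:
      tar -zxvf deeplabv3_cityscapes_train_2018_02_06.tar.gz
      Note that the extracted directory is: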
    /lwh/models/research/deeplab/deeplabv3_cityscapes_train

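      (2) Correct for class imbalance

        The dataset in the blackboard segmentation example is a 3-class problem in which background takes up a very large share of the pixels,
        so the weights are set to 1, 3, 3.
        Note: the weights affect the final segmentation quality, and suitable values differ from dataset to dataset.
        Change the weights at line 145 of common.py as follows: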
      
    flags.DEFINE_multi_float(
        'label_weights', [1.0,3.0,3.0],
        'A list of label weights, each element represents the weight for the label '
        'of its index, for example, label_weights = [0.1, 0.5] means the weight '
        'for label 0 is 0.1 and the weight for label 1 is 0.5. If set as None, all '
        'the labels have the same weight 1.0.')
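      To make the effect of label_weights concrete, here is a NumPy sketch of a per-pixel weighted softmax cross-entropy: each pixel's loss is scaled by label_weights[label], and pixels equal to ignore_label get weight 0. It illustrates the idea and is not DeepLab's exact loss; in particular, dividing by the total weight is an assumption, and DeepLab's own normalization may differ:

```python
import numpy as np


def weighted_pixel_cross_entropy(logits, labels, label_weights, ignore_label=255):
    """Weighted mean softmax cross-entropy over pixels.

    logits: float array of shape (H, W, num_classes)
    labels: int array of shape (H, W); pixels equal to ignore_label are skipped
    label_weights: one weight per class, e.g. [1.0, 3.0, 3.0]
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    valid = labels != ignore_label
    # Map ignored pixels to class 0 so indexing is safe; their weight is zeroed below.
    safe_labels = np.where(valid, labels, 0)
    weights = np.asarray(label_weights, dtype=float)[safe_labels] * valid

    # Negative log-probability of the true class at every pixel.
    nll = -np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    total_weight = weights.sum()
    # Assumption: normalize by the total weight (DeepLab may normalize differently).
    return float((weights * nll).sum() / total_weight) if total_weight > 0 else 0.0


logits = np.zeros((1, 2, 3))
logits[0, 1] = [0.0, 10.0, 0.0]  # second pixel confidently predicts class 1
labels = np.array([[0, 1]])
weighted = weighted_pixel_cross_entropy(logits, labels, [1.0, 3.0, 3.0])
uniform = weighted_pixel_cross_entropy(logits, labels, [1.0, 1.0, 1.0])
# The correctly classified class-1 pixel now counts three times, so the mean loss is lower.
print(weighted < uniform)  # True
```

      In the same way, a misclassified blackboard or screen pixel contributes three times as much loss as a background pixel, which counteracts the dominance of background.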

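      (3) Training

        Note the following flags:
        train_logdir: where the files produced by training are stored
        dataset_dir: the directory holding the dataset's TFRecord files
        dataset: the dataset name registered in data_generator.py
        
        The training command for your own dataset,
        run from ~/models/research/deeplab: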
      
    python train.py --training_number_of_steps=30000 --train_split="train" --model_variant="xception_65" \
    --atrous_rates=6 --atrous_rates=12 --atrous_rates=18 --output_stride=16 --decoder_output_stride=4 \
    --train_crop_size=801,801 --train_batch_size=2 --dataset="camvid" \
    --tf_initial_checkpoint='/lwh/models/research/deeplab/deeplabv3_cityscapes_train/model.ckpt' \
    --train_logdir='/lwh/models/research/deeplab/exp/blackboard_train/train' \
    --dataset_dir='/lwh/models/research/deeplab/datasets/blackboard/tfrecord'

        Rule for choosing train_crop_size:

        output_stride * k + 1, where k is an integer, for example 321x321, 513x513 or 801x801.
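      The rule can be checked, or a size rounded up to the next valid value, with two small helpers. The function names are hypothetical, and output_stride defaults to the 16 used in the training command:

```python
def valid_crop_size(min_size, output_stride=16):
    """Smallest size of the form output_stride * k + 1 that is >= min_size."""
    k = (min_size - 1 + output_stride - 1) // output_stride  # ceil((min_size - 1) / stride)
    return output_stride * k + 1


def is_valid_crop_size(size, output_stride=16):
    """True if size == output_stride * k + 1 for some integer k."""
    return (size - 1) % output_stride == 0


print([is_valid_crop_size(s) for s in (321, 513, 801)])  # [True, True, True]
print(valid_crop_size(800))                               # 801
print(valid_crop_size(1080), valid_crop_size(1920))       # 1089 1921
```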

      (4) Export the model

      

    python export_model.py \
        --logtostderr \
        --checkpoint_path="/lwh/models/research/deeplab/exp/blackboard_train/train/model.ckpt-30000" \
        --export_path="/lwh/models/research/deeplab/exp/blackboard_train/train/frozen_inference_graph.pb" \
        --model_variant="xception_65" \
        --atrous_rates=6 \
        --atrous_rates=12 \
        --atrous_rates=18 \
        --output_stride=16 \
        --decoder_output_stride=4 \
        --num_classes=3 \
        --crop_size=1080 \
        --crop_size=1920 \
        --inference_scales=1.0

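      Points to note:

      --checkpoint_path is the path where your trained model was saved

      --export_path is where the exported model is written

      --num_classes=3 is the number of classes in your data, including background

           --crop_size=1080  the first value is the input height h expected by the model

           --crop_size=1920  the second value is the input width w expected by the model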
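      A common slip is that PIL reports image.size as (width, height), while the two --crop_size flags are given height first. A minimal sketch of a hypothetical helper that emits the flags in the right order from a PIL-style size tuple:

```python
def crop_size_flags(pil_size):
    """Turn a PIL (width, height) tuple into export_model.py --crop_size flags, height first."""
    width, height = pil_size
    return ['--crop_size=%d' % height, '--crop_size=%d' % width]


print(crop_size_flags((1920, 1080)))  # ['--crop_size=1080', '--crop_size=1920']
```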

    3. Model testing

      Here is the code:

      

    # -*- coding: utf-8 -*-
    
    # DeepLab demo. Uses the TensorFlow 1.x API (tf.GraphDef, tf.Session);
    # under TF 2.x, `import tensorflow.compat.v1 as tf` is a common substitute.
    
    import os
    
    from matplotlib import gridspec
    import matplotlib.pyplot as plt
    import numpy as np
    from PIL import Image
    
    import tensorflow as tf
    
    
    class DeepLabModel(object):
        """
        Loads a frozen DeepLab model and runs inference.
        """
        INPUT_TENSOR_NAME = 'ImageTensor:0'
        OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
        INPUT_SIZE = 1920
        FROZEN_GRAPH_NAME = 'frozen_inference_graph'
    
        def __init__(self, frozen_graph_path):
            """
            Creates and loads the exported (frozen) DeepLab model from a .pb file.
            """
            self.graph = tf.Graph()
    
            # export_model.py writes a plain frozen GraphDef, not a tar archive.
            with open(frozen_graph_path, 'rb') as f:
                graph_def = tf.GraphDef.FromString(f.read())
    
            with self.graph.as_default():
                tf.import_graph_def(graph_def, name='')
            self.sess = tf.Session(graph=self.graph)
    
        def run(self, image):
            """
        Runs inference on a single image.
        Args:
        image: A PIL.Image object, raw input image.
        Returns:
        resized_image: RGB image resized from original input image.
        seg_map: Segmentation map of `resized_image`.
        """
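            width, height = image.size
            # Aspect-preserving resize: the longer side becomes INPUT_SIZE.
            resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
            target_size = (int(resize_ratio * width), int(resize_ratio * height))
            # The author overrides this with a fixed (width, height) matching the export
            # crop_size (height 1080, width 1920). Delete this line to keep the aspect
            # ratio, as long as the result still fits within the export crop_size.
            target_size = (1920, 1080)
            # Image.LANCZOS is the filter formerly aliased as Image.ANTIALIAS (removed in Pillow 10).
            resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)
            batch_seg_map = self.sess.run(self.OUTPUT_TENSOR_NAME,
                                          feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})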
            seg_map = batch_seg_map[0]
            return resized_image, seg_map
    
    
    def create_pascal_label_colormap():
        """
      Creates a label colormap used in PASCAL VOC segmentation benchmark.
      Returns:
          A Colormap for visualizing segmentation results.
      """
        colormap = np.zeros((256, 3), dtype=int)
        ind = np.arange(256, dtype=int)
    
        for shift in reversed(range(8)):
            for channel in range(3):
                colormap[:, channel] |= ((ind >> channel) & 1) << shift
            ind >>= 3
    
        return colormap
    
    
    def label_to_color_image(label):
        """
      Adds color defined by the dataset colormap to the label.
      Args:
          label: A 2D array with integer type, storing the segmentation label.
      Returns:
          result: A 2D array with floating type. The element of the array
          is the color indexed by the corresponding element in the input label
          to the PASCAL color map.
      Raises:
          ValueError: If label is not of rank 2 or its value is larger than color
          map maximum entry.
      """
        if label.ndim != 2:
            raise ValueError('Expect 2-D input label')
    
        colormap = create_pascal_label_colormap()
    
        if np.max(label) >= len(colormap):
            raise ValueError('label value too large.')
    
        return colormap[label]
    
    
    def vis_segmentation(image, seg_map):
        """Visualizes input image, segmentation map and overlay view."""
        plt.figure(figsize=(15, 5))
        grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])
    
        plt.subplot(grid_spec[0])
        plt.imshow(image)
        plt.axis('off')
        plt.title('input image')
    
        plt.subplot(grid_spec[1])
        seg_image = label_to_color_image(seg_map).astype(np.uint8)
        plt.imshow(seg_image)
        plt.axis('off')
        plt.title('segmentation map')
    
        plt.subplot(grid_spec[2])
        plt.imshow(image)
        plt.imshow(seg_image, alpha=0.7)
        plt.axis('off')
        plt.title('segmentation overlay')
    
        unique_labels = np.unique(seg_map)
        ax = plt.subplot(grid_spec[3])
        plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
        ax.yaxis.tick_right()
        plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
        plt.xticks([], [])
        ax.tick_params(width=0.0)
        plt.grid(False)
        plt.show()
    
    LABEL_NAMES = np.asarray(
        ['background', 'blackboard','screen'])
    
    FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
    FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
    
    
    
    download_path = r"D:\python_project\deeplabv3+\blackboard_v2.pb"  # exported frozen graph; change to your own path
    
    MODEL = DeepLabModel(download_path)
    print('model loaded successfully!')
    
    
    def run_visualization(imagefile):
        """
        Runs DeepLab semantic segmentation on one image and visualizes the result.
        """
        original_im = Image.open(imagefile)
        print('running deeplab on image %s...' % imagefile)
        resized_im, seg_map = MODEL.run(original_im)
        print(seg_map.shape)
    
        vis_segmentation(resized_im, seg_map)
    
    
    images_dir = r'D:\python_project\deeplabv3+\test_img'  # directory containing the test images
    images = sorted(os.listdir(images_dir))
    for imgfile in images:
        run_visualization(os.path.join(images_dir, imgfile))
    
    print('Done.')

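      Two things to note:

      1. Change images_dir to the directory holding your own images.

      2. Change INPUT_SIZE (1920 here) to the larger of your images' height and width.

      Test results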

  • Original article: https://www.cnblogs.com/liuwenhua/p/15136718.html