zoukankan      html  css  js  c++  java
  • 自己写的制作 city的语义分割tfrecord 适用于deeplabv3+

    自己写的制作 Cityscapes 语义分割 TFRecord 的脚本,适用于 DeepLabv3+

    自用

    """Converts the Cityscapes dataset to TFRecords file format."""
    
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import argparse
    import io
    import os
    import sys
    import natsort
    import PIL.Image
    import tensorflow as tf
    
    from utils import dataset_util
    
    # Command-line flags for the conversion script.
    # NOTE(review): --train_data_list and --valid_data_list do not appear to be
    # read by the code in this file; they stay accepted so that existing command
    # lines keep working.
    parser = argparse.ArgumentParser()

    parser.add_argument('--data_dir', type=str, default='/home/a/dataset/cityscapes/',
                        help='Path to the directory containing the Cityscapes data.')

    parser.add_argument('--output_path', type=str, default='./dataset',
                        help='Path to the directory to create TFRecords outputs.')

    parser.add_argument('--train_data_list', type=str, default='./dataset/train.txt',
                        help='Path to the file listing the training data.')

    parser.add_argument('--valid_data_list', type=str, default='./dataset/val.txt',
                        help='Path to the file listing the validation data.')

    # The 'train' and 'val' split folders are looked up below these two
    # sub-directories of data_dir.
    parser.add_argument('--image_data_dir', type=str, default='leftImg8bit',
                        help='Sub-directory of data_dir containing the image data.')

    parser.add_argument('--label_data_dir', type=str, default='gtFine',
                        help='Sub-directory of data_dir containing the label data.')
    
    
    def dict_to_tf_example(image_path,
                           label_path):
      """Build a tf.Example holding one PNG image and its PNG label map.

      Both files are stored in the example as their raw encoded bytes; they are
      opened with PIL only to check the format and to read the dimensions.

      Args:
        image_path: Path to the input image file (must be a PNG).
        label_path: Path to the corresponding label file (must be a PNG).

      Returns:
        A tf.train.Example with the features 'image/height', 'image/width',
        'image/encoded', 'image/format', 'label/encoded' and 'label/format'.

      Raises:
        ValueError: if either file is not a PNG, or if the size of the image
                    does not match that of the label.
      """
      def _load_png(path, error_message):
        # Returns the raw file bytes together with the PIL image opened on them.
        with tf.gfile.GFile(path, 'rb') as handle:
          raw_bytes = handle.read()
        opened = PIL.Image.open(io.BytesIO(raw_bytes))
        if opened.format != 'PNG':
          raise ValueError(error_message)
        return raw_bytes, opened

      # The image is fully validated before the label file is touched.
      image_bytes, image = _load_png(image_path, 'Image format not PNG')
      label_bytes, label = _load_png(label_path, 'Label format not PNG')

      if image.size != label.size:
        raise ValueError('The size of image does not match with that of label.')

      # PIL reports size as (width, height).
      width, height = image.size
      png_tag = 'png'.encode('utf8')

      feature_map = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/encoded': dataset_util.bytes_feature(image_bytes),
        'image/format': dataset_util.bytes_feature(png_tag),
        'label/encoded': dataset_util.bytes_feature(label_bytes),
        'label/format': dataset_util.bytes_feature(png_tag),
      }
      return tf.train.Example(features=tf.train.Features(feature=feature_map))
    def scanDir_img_File(dir):
        """Yield the path of every file found below `dir`, recursively.

        The tree is walked top-down, walk errors are ignored and symbolic
        links to directories are not followed.
        """
        walker = os.walk(dir, topdown=True, onerror=None, followlinks=False)
        for entry in walker:
            folder, file_names = entry[0], entry[2]
            for file_name in file_names:
                yield os.path.join(folder, file_name)
    
    def scanDir_lable_File(dir):
        """Yield the paths of the 'labelTrainIds' label files found below `dir`.

        A file is selected when the fifth underscore-separated field of its
        name (extension removed) is exactly 'labelTrainIds', e.g.
        'aachen_000000_000019_gtFine_labelTrainIds.png'.

        Args:
            dir: Root directory that is walked recursively.

        Yields:
            The full path of each matching file.
        """
        for root, dirs, files in os.walk(dir, True, None, False):  # walk the tree
            for f in files:
                path = os.path.join(root, f)
                if not os.path.isfile(path):
                    continue
                fields = os.path.splitext(f)[0].split('_')
                # Compare for equality: the previous `in ('labelTrainIds')` was a
                # substring test against a plain string, so fields such as
                # 'label' were accepted as well. Names with fewer than five
                # fields are skipped instead of raising IndexError.
                if len(fields) > 4 and fields[4] == 'labelTrainIds':
                    yield path
    
    def create_tf_record(output_filename,
                         image_dir,
                         label_dir):
      """Creates a TFRecord file from examples.

      Images and labels are collected separately, each list is sorted in natural
      order, and the two lists are paired position by position. Pairs that fail
      validation (ValueError from dict_to_tf_example) are logged and skipped.

      Args:
        output_filename: Path to where output file is saved.
        image_dir: Directory where image files are stored.
        label_dir: Directory where label files are stored.
      """
      image_list = natsort.natsorted(list(scanDir_img_File(image_dir)))
      label_list = natsort.natsorted(list(scanDir_lable_File(label_dir)))

      # Pairing relies purely on sort order, so a count mismatch almost certainly
      # means misaligned pairs; zip() would otherwise truncate silently.
      if len(image_list) != len(label_list):
        tf.logging.warning(
            'Found %d images but %d labels; only the first %d sorted pairs are '
            'written.', len(image_list), len(label_list),
            min(len(image_list), len(label_list)))

      writer = tf.python_io.TFRecordWriter(output_filename)
      try:
        for image_path, label_path in zip(image_list, label_list):
          print(image_path, label_path)
          try:
            tf_example = dict_to_tf_example(image_path, label_path)
            writer.write(tf_example.SerializeToString())
          except ValueError:
            # Supply the argument for %s so the offending file is actually named.
            tf.logging.warning('Invalid example: %s, ignoring.', image_path)
      finally:
        # Close the writer even when an unexpected exception escapes the loop.
        writer.close()
    
    
    def main(unused_argv):
      """Write the train and val TFRecord files described by FLAGS.

      Creates FLAGS.output_path when it does not exist, then converts the
      'train' split followed by the 'val' split.

      Args:
        unused_argv: Ignored.
      """
      output_root = FLAGS.output_path
      if not os.path.exists(output_root):
        os.makedirs(output_root)

      tf.logging.info("Reading from CITY dataset")

      image_root = os.path.join(FLAGS.data_dir, FLAGS.image_data_dir)
      label_root = os.path.join(FLAGS.data_dir, FLAGS.label_data_dir)

      # Resolve every path first, then run the conversions in order.
      jobs = []
      for split, record_name in (('train', 'city_train.record'),
                                 ('val', 'city_val.record')):
        jobs.append((os.path.join(output_root, record_name),
                     os.path.join(image_root, split),
                     os.path.join(label_root, split)))

      for record_path, split_image_dir, split_label_dir in jobs:
        create_tf_record(record_path, split_image_dir, split_label_dir)
    
    
    if __name__ == '__main__':
      # Script entry point: enable INFO-level TensorFlow logging, parse the flags
      # and hand control to main() through tf.app.run.
      tf.logging.set_verbosity(tf.logging.INFO)
      # FLAGS is bound at module level because main() reads it as a global.
      # Arguments argparse does not recognise are collected in `unparsed` and
      # forwarded to tf.app.run together with the program name.
      FLAGS, unparsed = parser.parse_known_args()
      tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
     
  • 相关阅读:
    grunt in webstorm
    10+ Best Responsive HTML5 AngularJS Templates
    响应式布局
    responsive grid
    responsive layout
    js event bubble and capturing
    Understanding Service Types
    To add private variable to this Javascript literal object
    Centering HTML elements larger than their parents
    java5 新特性
  • 原文地址:https://www.cnblogs.com/ansang/p/8631857.html
Copyright © 2011-2022 走看看