zoukankan      html  css  js  c++  java
  • 将VOC2012转换为tfrecord

    PASCAL-VOC2012简介

    PASCAL-VOC2012数据集介绍官网:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html ,数据集下载地址:benchmark_RELEASE:下载地址 voc2012:下载地址

    VOC2012数据集分为20类,包括背景为21类,分别如下: 

    • Person: person 
    • Animal: bird, cat, cow, dog, horse, sheep 
    • Vehicle: aeroplane, bicycle, boat, bus, car, motorbike, train 
    • Indoor: bottle, chair, dining table, potted plant, sofa, tv/monitor

    再看一下VOC2012数据集里有哪些文件夹:

    在目标检测中,主要用到了 Annotations,ImageSets,JPEGImages,其中 ImageSets/Main/ 保存了具体数据集的索引,Annotations 保存了标签数据, JPEGImages 保存了图片内容。

    ImageSets/Main/ 文件夹以 , {class}_trainval.txt {class}_val.txt 的格式命名。 train.txt val.txt 例外,包括 Action,Layout,Main,Segmentation 四个文件夹:

    • Action:存放的是人的动作(例如running、jumping等等,这也是VOC challenge的一部分)
    • Layout:存放的是具有人体部位的数据(人的head、hand、feet等等,这也是VOC challenge的一部分
    • Main:存放的是图像物体识别的数据,总共分为20类。
    • Segmentation:存放的是可用于分割的数据。

    在图像分割中,主要使用了SegmentationClass,SegmentationObject,JPEGImages有关的信息,VOC2012中的图片并不是都用于分割,用于分割比赛的图片实例如下,包含原图以及图像分类分割和图像物体分割两种png图。图像分类分割是在20种物体中,ground-turth图片上每个物体的轮廓填充都有一个特定的颜色,一共20种颜色,比如摩托车用红色表示,人用绿色表示。而图像物体分割则仅仅在一副图中生成不同物体的轮廓颜色即可,颜色自己随便填充。

     ImageSets/Main/ 文件夹以 , {class}_trainval.txt {class}_val.txt 的格式命名。 train.txt val.txt 例外

    aeroplane_train.txt
    aeroplane_trainval.txt
    aeroplane_val.txt
    bicycle_train.txt
    bicycle_trainval.txt
    bicycle_val.txt
    bird_train.txt
    bird_trainval.txt
    bird_val.txt
    boat_train.txt
    boat_trainval.txt
    boat_val.txt
    bottle_train.txt
    bottle_trainval.txt
    bottle_val.txt
    bus_train.txt
    bus_trainval.txt
    bus_val.txt
    car_train.txt
    car_trainval.txt
    car_val.txt
    cat_train.txt
    cat_trainval.txt
    cat_val.txt
    chair_train.txt
    chair_trainval.txt
    chair_val.txt
    cow_train.txt
    cow_trainval.txt
    cow_val.txt
    diningtable_train.txt
    diningtable_trainval.txt
    diningtable_val.txt
    dog_train.txt
    dog_trainval.txt
    dog_val.txt
    horse_train.txt
    horse_trainval.txt
    horse_val.txt
    motorbike_train.txt
    motorbike_trainval.txt
    motorbike_val.txt
    person_train.txt
    person_trainval.txt
    person_val.txt
    pottedplant_train.txt
    pottedplant_trainval.txt
    pottedplant_val.txt
    sheep_train.txt
    sheep_trainval.txt
    sheep_val.txt
    sofa_train.txt
    sofa_trainval.txt
    sofa_val.txt
    train.txt
    train_train.txt
    train_trainval.txt
    train_val.txt
    trainval.txt
    tvmonitor_train.txt
    tvmonitor_trainval.txt
    tvmonitor_val.txt
    val.txt

    • {class}_train.txt 保存类别为 class 的训练集的所有索引,每一个 class 的 train 数据都有 5717 个。
    • {class}_val.txt 保存类别为 class 的验证集的所有索引,每一个 class 的val数据都有 5823 个
    • {class}_trainval.txt 保存类别为 class 的训练验证集的所有索引,每一个 class 的val数据都有11540 个

    每个文件包含内容为:

    2011_003194 -1
    2011_003216 -1
    2011_003223 -1
    2011_003230 1
    2011_003236 1
    2011_003238 1
    2011_003246 1
    2011_003247 0
    2011_003253 -1
    2011_003255 1
    2011_003259 1
    2011_003274 -1
    2011_003276 -1

    注:1代表正样本,-1代表负样本。

    VOC2012/ImageSets/Main/train.txt 保存了所有训练集的文件名,从 VOC2012/JPEGImages/ 找到文件名对应的图片文件。VOC2012/Annotations/ 找到文件名对应的标签文件

    VOC2012/ImageSets/Main/val.txt 保存了所有验证集的文件名,从 VOC2012/JPEGImages/ 找到文件名对应的图片文件。VOC2012/Annotations/ 找到文件名对应的标签文件

    读取 JPEGImages 和 Annotation 文件转换为 tf 的 Example 对象,写入 {train|test}{index}_of{num_shard} 文件。每个文件写的 Example 的数量为 total_size/num_shard。(不同数据集可以适当调节 num_shard 来控制每个输出文件的大小)

    Annotations

    文件夹中文件以 {id}.xml (id 保存在 VOC2012/ImageSets/Main/文件夹 ) 格式命名的 xml 文件,保存如下关键信息

    • 物体 label : name ,如下例子为 person
    • 图片尺寸: depth, height, width
    • 物体 bbox : bndbox 下 xmax, xmin, ymax, ymin
    <annotation>
    	<folder>VOC2012</folder>
    	<filename>2007_000032.jpg</filename>
    	<source>
    		<database>The VOC2007 Database</database>
    		<annotation>PASCAL VOC2007</annotation>
    		<image>flickr</image>
    	</source>
    	<size>
    		<width>500</width>
    		<height>281</height>
    		<depth>3</depth>
    	</size>
    	<segmented>1</segmented>
    	<object>
    		<name>aeroplane</name>
    		<pose>Frontal</pose>
    		<truncated>0</truncated>
    		<difficult>0</difficult>
    		<bndbox>
    			<xmin>104</xmin>
    			<ymin>78</ymin>
    			<xmax>375</xmax>
    			<ymax>183</ymax>
    		</bndbox>
    	</object>
    	<object>
    		<name>aeroplane</name>
    		<pose>Left</pose>
    		<truncated>0</truncated>
    		<difficult>0</difficult>
    		<bndbox>
    			<xmin>133</xmin>
    			<ymin>88</ymin>
    			<xmax>197</xmax>
    			<ymax>123</ymax>
    		</bndbox>
    	</object>
    	<object>
    		<name>person</name>
    		<pose>Rear</pose>
    		<truncated>0</truncated>
    		<difficult>0</difficult>
    		<bndbox>
    			<xmin>195</xmin>
    			<ymin>180</ymin>
    			<xmax>213</xmax>
    			<ymax>229</ymax>
    		</bndbox>
    	</object>
    	<object>
    		<name>person</name>
    		<pose>Rear</pose>
    		<truncated>0</truncated>
    		<difficult>0</difficult>
    		<bndbox>
    			<xmin>26</xmin>
    			<ymin>189</ymin>
    			<xmax>44</xmax>
    			<ymax>238</ymax>
    		</bndbox>
    	</object>
    </annotation>

    tfrecord格式简介

    tfrecord是Tensorflow官方推荐的一种较为高效的数据读取方式。使用Tensorflow训练神经网络时,读取的数据方式有很多种。如果数据集比较小,而且内存足够大,可以选择直接将所有数据读进内存,然后每次取一个batch的数据出来。如果数据较多,可以每次直接从硬盘中进行读取,不过这种方式的读取效率就比较低了。
    tfrecord其实是一种数据存储形式。使用tfrecord时,实际上是先读取原生数据,然后转换成tfrecord格式,再存储在硬盘上。而使用时,再把数据从相应的tfrecord文件中解码读取出来。

    Tensorflow有和tfrecord配套的一些函数,可以加快数据的处理。实际读取tfrecord数据时,先以相应的tfrecord文件为参数,创建一个输入队列,这个队列有一定的容量,用户可以设置不同的值,在一部分数据出队列时,tfrecord中的其他数据就可以通过预取进入队列,并且这个过程和网络的计算是独立进行的。也就是说,网络每一个iteration的训练不必等待数据队列准备好再开始,队列中的数据始终是充足的,而往队列中填充数据时,也可以使用多线程加速。

    tfecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义。

    message Example {
      Features features = 1;
    };
    
    message Features{
      map<string,Feature> featrue = 1;
    };
    
    message Feature{
      oneof kind{
            BytesList bytes_list = 1;
            FloatList float_list = 2;
            Int64List int64_list = 3;
        }
    };
    

    tf.train.Example中包含了属性名称到取值的字典,其中属性名称为字符串,属性的取值可以为字符串(BytesList)、实数列表(FloatList)或者整数列表(Int64List)。

    将数据保存为tfrecord格式

    首先,创建以tfrecord为后缀的文件名

    tfrecords_filename = './tfrecords/train.tfrecords'
    writer = tf.python_io.TFRecordWriter(tfrecords_filename) # 创建.tfrecord文件,准备写入

    然后创建一个循环一次写入数据

        for i in range(100):
            img_raw = np.random.random_integers(0,255,size=(7,30)) # 创建7*30,取值在0-255之间随机数组
            img_raw = img_raw.tostring()
            example = tf.train.Example(features=tf.train.Features(
                    feature={
                    'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),     
                    'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
                    }))
            writer.write(example.SerializeToString()) 
        
        writer.close()

    example = tf.train.Example()这句将数据赋给了变量example(可以看到里面是通过字典结构实现的赋值),然后用writer.write(example.SerializeToString()) 这句实现写入。

    值得注意的是赋值给example的数据格式。从前面tf.train.Example的定义可知,tfrecord支持整型、浮点数和二进制三种格式,分别是

    tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))
    tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))
    tf.train.Feature(bytes_list = tf.train.FloatList(value=[float_scalar]))

    例如图片等数组形式(array)的数据,可以保存为numpy array的格式,转换为string,然后保存到二进制格式的feature中。对于单个的数值(scalar),可以直接赋值。这里value=[×]的[]非常重要,也就是说输入的必须是列表(list)。当然,对于输入数据是向量形式的,可以根据数据类型(float还是int)分别保存。并且在保存的时候还可以指定数据的维数。

    读取tfrecord数据

    tf.parse_single_example解码,tf.TFRecordReader读取,一般,为了高效的读取数据,tf中使用队列读取数据

    def read_and_decode(filename):
        # 生成一个文件名的队列
        filename_queue = tf.train.string_input_producer([filename])
        reader = tf.TFRecordReader()  # 定义一个reader
        _, serialized_example = reader.read(filename_queue)   # 读取文件名和example
    
        # 还原feature, 和制作tfrecords时一样
        feature = { 'label': tf.FixedLenFeature([], tf.int64),  # 对于单个元素的变量,我们使用FixlenFeature来读取,需要指明变量存储的数据类型;对于list类型的变量,我们使用VarLenFeature来读取,同样需要指明读取变量的类型
                    'img_raw' : tf.FixedLenFeature([], tf.string), }
        # 使用tf.parse_single_example来解析example
        features = tf.parse_single_example(serialized_example, features=feature)
    
        # 对于图像,使用tf.decode_raw解析对应的features,指定类型,然后reshape等
        img = tf.decode_raw(features['img_raw'], tf.uint8)
        img = tf.reshape(img, [224, 224, 3])
        img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
        label = tf.cast(features['label'], tf.int32)
    
        return img, label
    
    img, label = read_and_decode('train.tfrecords')
    # 在训练时使用shuffle_batch随机打乱顺序,并生成batch
    img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                    batch_size=30, 
                                                    capacity=2000,  # 队列的最大容量
                                                    num_threads=1,  # 进行队列操作的线程数
                                                    min_after_dequeue=1000) # dequeue后最小的队列大小,used to ensure a level of mixing of elements.
    
    # tf队列也需要初始化在sess中才能执行                      
    init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
    with tf.Session() as sess:
        sess.run(init_op)
    
        coord = tf.train.Coordinator()  # 创建一个coordinate,用于协调各线程
        threads = tf.train.start_queue_runners(coord=coord)  # 使用QueueRunner对象来提取数据
    
        try:  # 推荐代码
            while not coord.should_stop():
                # Run training steps or whatever
                sess.run(train_op)
        except tf.errors.OutOfRangeError:
            print 'Done training -- epoch limit reached'
        finally:
            # When done, ask the threads to stop.关闭线程
            coord.request_stop()
    
        # Wait for threads to finish.
        coord.join(threads)

    以目标检测所使用的文件为例,制作tfrecord文件代码如下:

    # coding=utf-8
    import os
    import sys
    import random
    
    import numpy as np
    import tensorflow as tf
    # process a xml file
    import xml.etree.ElementTree as ET
    
    DIRECTORY_ANNOTATIONS = 'Annotations/'
    DIRECTORY_IMAGES = 'JPEGImages/'
    RANDOM_SEED = 4242
    SAMPLES_PER_FILES = 20000
    
    VOC_LABELS = {
        'none': (0, 'Background'),
        'aeroplane': (1, 'Vehicle'),
        'bicycle': (2, 'Vehicle'),
        'bird': (3, 'Animal'),
        'boat': (4, 'Vehicle'),
        'bottle': (5, 'Indoor'),
        'bus': (6, 'Vehicle'),
        'car': (7, 'Vehicle'),
        'cat': (8, 'Animal'),
        'chair': (9, 'Indoor'),
        'cow': (10, 'Animal'),
        'diningtable': (11, 'Indoor'),
        'dog': (12, 'Animal'),
        'horse': (13, 'Animal'),
        'motorbike': (14, 'Vehicle'),
        'person': (15, 'Person'),
        'pottedplant': (16, 'Indoor'),
        'sheep': (17, 'Animal'),
        'sofa': (18, 'Indoor'),
        'train': (19, 'Vehicle'),
        'tvmonitor': (20, 'Indoor'),
    }
    
    
    #返回一个int64_list
    def int64_feature(values):
        """Returns a TF-Feature of int64s.
        Args:
        values: A scalar or list of values.
        Returns:
        a TF-Feature.
        """
        if not isinstance(values, (tuple, list)):
            values = [values]
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
    
    #返回float_list
    def float_feature(value):
        """Wrapper for inserting float features into Example proto.
        """
        if not isinstance(value, list):
            value = [value]
        return tf.train.Feature(float_list=tf.train.FloatList(value=value))
    #返回bytes_list
    def bytes_feature(value):
        """Wrapper for inserting bytes features into Example proto.
        """
        if not isinstance(value, list):
            value = [value]
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
    
    #split的三种类型
    SPLIT_MAP = ['train', 'val', 'trainval']
    
    """
    Process a image and annotation file.
    Args:
        filename:       string, path to an image file e.g., '/path/to/example.JPG'.
        coder:          instance of ImageCoder to provide TensorFlow image coding utils.
    Returns:
        image_buffer:   string, JPEG encoding of RGB image.
        height:         integer, image height in pixels.
                  integer, image width in pixels.
    读取一个样本图片及对应信息
    directory:图片所在路径,name:图片名称
    """
    def _process_image(directory, name):
        # Read the image file.
        filename = os.path.join(directory, DIRECTORY_IMAGES, name + '.jpg')
        image_data = tf.gfile.FastGFile(filename, 'rb').read()  #使用gfile读取图片
        # Read the XML annotation file.
        filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
        tree = ET.parse(filename)   #XML文档表示为树,ElementTree
        root = tree.getroot()       #树的根节点
        # Image shape.
        size = root.find('size')
        shape = [int(size.find('height').text), int(size.find('width').text), int(size.find('depth').text)]
        # Find annotations.
        # 获取每个object的信息
        bboxes = []
        labels = []
        labels_text = []
        difficult = []
        truncated = []
        for obj in root.findall('object'):
            label = obj.find('name').text
            labels.append(int(VOC_LABELS[label][0]))
            labels_text.append(label.encode('ascii'))
    
            if obj.find('difficult'):
                difficult.append(int(obj.find('difficult').text))
            else:
                difficult.append(0)
            if obj.find('truncated'):
                truncated.append(int(obj.find('truncated').text))
            else:
                truncated.append(0)
    
            bbox = obj.find('bndbox')
            bboxes.append((float(bbox.find('ymin').text) / shape[0],
                           float(bbox.find('xmin').text) / shape[1],
                           float(bbox.find('ymax').text) / shape[0],
                           float(bbox.find('xmax').text) / shape[1]
                           ))
        return image_data, shape, bboxes, labels, labels_text, difficult, truncated
    
    """
    Build an Example proto for an image example.
    Args:
      image_data: string, JPEG encoding of RGB image;
      labels: list of integers, identifier for the ground truth;
      labels_text: list of strings, human-readable labels;
      bboxes: list of bounding boxes; each box is a list of integers;
          specifying [xmin, ymin, xmax, ymax]. All boxes are assumed to belong
          to the same label as the image label.
      shape: 3 integers, image shapes in pixels.
    Returns:
      Example proto
    将一个图片及对应信息按格式转换成训练时可读取的一个样本
    """
    def _convert_to_example(image_data, labels, labels_text, bboxes, shape, difficult, truncated):
        xmin = []
        ymin = []
        xmax = []
        ymax = []
        for b in bboxes:
            assert len(b) == 4
            # pylint: disable=expression-not-assigned
            [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]
            # pylint: enable=expression-not-assigned
    
        image_format = b'JPEG'
        example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(shape[0]),
            'image/width': int64_feature(shape[1]),
            'image/channels': int64_feature(shape[2]),
            'image/shape': int64_feature(shape),
            'image/object/bbox/xmin': float_feature(xmin),
            'image/object/bbox/xmax': float_feature(xmax),
            'image/object/bbox/ymin': float_feature(ymin),
            'image/object/bbox/ymax': float_feature(ymax),
            'image/object/bbox/label': int64_feature(labels),
            'image/object/bbox/label_text': bytes_feature(labels_text),
            'image/object/bbox/difficult': int64_feature(difficult),
            'image/object/bbox/truncated': int64_feature(truncated),
            'image/format': bytes_feature(image_format),
            'image/encoded': bytes_feature(image_data)}))
        return example
    
    
    """
    Loads data from image and annotations files and add them to a TFRecord.
    Args:
      dataset_dir: Dataset directory;
      name: Image name to add to the TFRecord;
      tfrecord_writer: The TFRecord writer to use for writing.
    """
    def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
        image_data, shape, bboxes, labels, labels_text, difficult, truncated = 
            _process_image(dataset_dir, name)
        example = _convert_to_example(image_data,
                                      labels,
                                      labels_text,
                                      bboxes,
                                      shape,
                                      difficult,
                                      truncated)
        tfrecord_writer.write(example.SerializeToString())
    
    
    """
    以VOC2012为例,下载后的文件名为:VOCtrainval_11-May-2012.tar,解压后
    得到一个文件夹:VOCdevkit
    voc_root就是VOCdevkit文件夹所在的路径
    在VOCdevkit文件夹下只有一个文件夹:VOC2012,所以下边参数year该文件夹的数字部分。
    在VOCdevkit/VOC2012/ImageSets/Main下存放了20个类别,每个类别有3个的txt文件:
    *.train.txt存放训练使用的数据
    *.val.txt存放测试使用的数据
    *.trainval.txt是train和val的合集
    所以参数split只能为'train', 'val', 'trainval'之一
    """
    def run(voc_root, year, split, output_dir, shuffling=False):
        # 如果output_dir不存在则创建
        if not tf.gfile.Exists(output_dir):
            tf.gfile.MakeDirs(output_dir)
        # VOCdevkit/VOC2012/ImageSets/Main/train.txt
        # 中存放有所有20个类别的训练样本名称,共5717个
        split_file_path = os.path.join(voc_root, 'VOC%s' % year, 'ImageSets', 'Main', '%s.txt' % split)
        print('>> ', split_file_path)
        with open(split_file_path) as f:
            filenames = f.readlines()
        # shuffling == Ture时,打乱顺序
        if shuffling:
            random.seed(RANDOM_SEED)
            random.shuffle(filenames)
        # Process dataset files.
        i = 0
        fidx = 0
        dataset_dir = os.path.join(voc_root, 'VOC%s' % year)
        while i < len(filenames):
            # Open new TFRecord file.
            tf_filename = '%s/%s_%03d.tfrecord' % (output_dir, split, fidx)
            with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
                j = 0
                while i < len(filenames) and j < SAMPLES_PER_FILES:
                    sys.stdout.write('
    >> Converting image %d/%d' % (i + 1, len(filenames)))
                    sys.stdout.flush()
                    filename = filenames[i].strip()
                    _add_to_tfrecord(dataset_dir, filename, tfrecord_writer)
                    i += 1
                    j += 1
                fidx += 1
        print('
    >> Finished converting the Pascal VOC dataset!')
    
    if __name__ == '__main__':
        # if len(sys.argv) < 2:
        #     raise ValueError('>> error. format: python *.py split_name')
        split = 'train'     #'train|val|trainval'
        if split not in SPLIT_MAP:
            raise ValueError('>> error. split = %s' % split)
        voc_root = 'E:/data/VOCdevkit/'
        run(voc_root, 2012, split,voc_root)

    以图像分割为例,代码如下

    代码中所需的build_data.py,点击打开

    # Copyright 2018 The TensorFlow Authors All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================
    
    """Converts PASCAL VOC 2012 data to TFRecord file format with Example protos.
    PASCAL VOC 2012 dataset is expected to have the following directory structure:
      + pascal_voc_seg
        - build_data.py
        - build_voc2012_data.py (current working directory).
        + VOCdevkit
          + VOC2012
            + JPEGImages
            + SegmentationClass
            + ImageSets
              + Segmentation
        + tfrecord
    Image folder:
      ./VOCdevkit/VOC2012/JPEGImages
    Semantic segmentation annotations:
      ./VOCdevkit/VOC2012/SegmentationClass
    list folder:
      ./VOCdevkit/VOC2012/ImageSets/Segmentation
    This script converts data into sharded data files and save at tfrecord folder.
    The Example proto contains the following fields:
      image/encoded: encoded image content.
      image/filename: image filename.
      image/format: image file format.
      image/height: image height.
      image/ image width.
      image/channels: image channels.
      image/segmentation/class/encoded: encoded semantic segmentation content.
      image/segmentation/class/format: semantic segmentation file format.
    """
    import math
    import os.path
    import sys
    import build_data 
    import tensorflow as tf
    
    FLAGS = tf.app.flags.FLAGS
    
    tf.app.flags.DEFINE_string('image_folder',
                               './pascal_voc_seg/VOCdevkit/VOC2012/JPEGImages',
                               'Folder containing images.')
    
    tf.app.flags.DEFINE_string(
        'semantic_segmentation_folder',
        './pascal_voc_seg/VOCdevkit/VOC2012/SegmentationClassRaw',
        'Folder containing semantic segmentation annotations.')
    #train.txt,val.txt,trainval.txt
    tf.app.flags.DEFINE_string(
        'list_folder',
        './pascal_voc_seg/VOCdevkit/VOC2012/ImageSets/Segmentation',
        'Folder containing lists for training and validation')
    
    #tfrecord输出路径
    tf.app.flags.DEFINE_string(
        'output_dir',
        './pascal_voc_seg/tfrecord',
        'Path to save converted SSTable of TensorFlow examples.')
    
    _NUM_SHARDS = 4
    
    
    def _convert_dataset(dataset_split):
        """Converts the specified dataset split to TFRecord format.
        Args:
          dataset_split: The dataset split (e.g., train, test).
        Raises:
          RuntimeError: If loaded image and label have different shape.
        """
        dataset = os.path.basename(dataset_split)[:-4]
        sys.stdout.write('Processing ' + dataset)
        filenames = [x.strip('
    ') for x in open(dataset_split, 'r')]
        num_images = len(filenames)
        num_per_shard = int(math.ceil(num_images / float(_NUM_SHARDS)))
    
        image_reader = build_data.ImageReader('jpg', channels=3)
        label_reader = build_data.ImageReader('png', channels=1)
    
        for shard_id in range(_NUM_SHARDS):
            output_filename = os.path.join(
                FLAGS.output_dir,
                '%s-%05d-of-%05d.tfrecord' % (dataset, shard_id, _NUM_SHARDS))
            with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                start_idx = shard_id * num_per_shard
                end_idx = min((shard_id + 1) * num_per_shard, num_images)
                for i in range(start_idx, end_idx):
                    sys.stdout.write('
    >> Converting image %d/%d shard %d' % (
                        i + 1, len(filenames), shard_id))
                    sys.stdout.flush()
                    # Read the image.
                    image_filename = os.path.join(
                        FLAGS.image_folder, filenames[i] + '.jpg' )#+ FLAGS.image_format)
                    image_data = tf.gfile.FastGFile(image_filename, 'rb').read()
                    height, width = image_reader.read_image_dims(image_data)
                    # Read the semantic segmentation annotation.
                    seg_filename = os.path.join(
                        FLAGS.semantic_segmentation_folder,
                        filenames[i] + '.' + FLAGS.label_format)
                    seg_data = tf.gfile.FastGFile(seg_filename, 'rb').read()
                    seg_height, seg_width = label_reader.read_image_dims(seg_data)
                    if height != seg_height or width != seg_
                        raise RuntimeError('Shape mismatched between image and label.')
                    # Convert to tf example.
                    example = build_data.image_seg_to_tfexample(
                        image_data, filenames[i], height, width, seg_data)
                    tfrecord_writer.write(example.SerializeToString())
            sys.stdout.write('
    ')
            sys.stdout.flush()
    
    
    def main(unused_argv):
        dataset_splits = tf.gfile.Glob(os.path.join(FLAGS.list_folder, '*.txt'))
        for dataset_split in dataset_splits:
            _convert_dataset(dataset_split)
    
    
    if __name__ == '__main__':
        tf.app.run()

    参考链接一

    参考链接二

    参考链接三

    参考链接

    天上我才必有用,千金散尽还复来!
  • 相关阅读:
    Android官方架构组件介绍之ViewModel
    Android官方架构组件介绍之LiveData
    Android官方架构组件介绍之LifeCycle
    Android N 通知概览及example
    Project和Task
    hello gradle
    写出gradle风格的groovy代码
    Groovy中的面向对象
    tcp_tw_recycle和tcp_timestamps的一些知识(转)
    Xtrabackup 热备
  • 原文地址:https://www.cnblogs.com/lutaishi/p/13436218.html
Copyright © 2011-2022 走看看