zoukankan      html  css  js  c++  java
  • 生成TFRecord文件完整代码实例

    import os
    import json
    
    
    def get_annotation_dict(input_folder_path, word2number_dict):
        label_dict = {}
        father_file_list = os.listdir(input_folder_path)
        for father_file in father_file_list:
            full_father_file = os.path.join(input_folder_path, father_file)
            son_file_list = os.listdir(full_father_file)
            for image_name in son_file_list:
                label_dict[os.path.join(full_father_file, image_name)] = word2number_dict[father_file]
        return label_dict
    
    
    def save_json(label_dict, json_path):
        with open(json_path, 'w') as json_path:
            json.dump(label_dict, json_path)
        print("label json file has been generated successfully!")
    1. generate_annotation_json.py

    总共有七种分类图片,类别的名称就是每个文件夹名称

    generate_annotation_json.py是为了得到图片标注的label_dict。通过这个代码块可以获得我们需要的图片标注字典,key是图片具体地址, value是图片的类别,具体实例如下:
    {
    "/images/hangs/862e67a8-5bd9-41f1-8c6d-876a3cb270df.JPG": 6, 
    "/images/tags/adc264af-a76b-4477-9573-ac6c435decab.JPG": 3, 
    "/images/tags/fd231f5a-b42c-43ba-9e9d-4abfbaf38853.JPG": 3, 
    "/images/hangs/2e47d877-1954-40d6-bfa2-1b8e3952ebf9.jpg": 6, 
    "/images/tileds/a07beddc-4b39-4865-8ee2-017e6c257e92.png": 5,
     "/images/models/642015c8-f29d-4930-b1a9-564f858c40e5.png": 4
    }
    1. generate_tfrecord.py

    import os
    import tensorflow as tf
    import io
    from PIL import Image
    from generate_annotation_json import get_annotation_dict

    flags = tf.app.flags
    flags.DEFINE_string('images_dir',
    '/data2/raycloud/jingxiong_datasets/six_classes/images',
    'Path to image(directory)')
    flags.DEFINE_string('annotation_path',
    '/data1/humaoc_file/classify/data/annotations/annotations.json',
    'Path to annotation')
    flags.DEFINE_string('record_path',
    '/data1/humaoc_file/classify/data/train_tfrecord/train.record',
    'Path to TFRecord')
    FLAGS = flags.FLAGS


    def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


    def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


    def process_image_channels(image):
    process_flag = False
    # process the 4 channels .png
    if image.mode == 'RGBA':
    r, g, b, a = image.split()
    image = Image.merge("RGB", (r,g,b))
    process_flag = True
    # process the channel image
    elif image.mode != 'RGB':
    image = image.convert("RGB")
    process_flag = True
    return image, process_flag


    def process_image_reshape(image, resize):
    width, height = image.size
    if resize is not None:
    if width > height:
    width = int(width * resize / height)
    height = resize
    else:
    width = resize
    height = int(height * resize / width)
    image = image.resize((width, height), Image.ANTIALIAS)
    return image


    def create_tf_example(image_path, label, resize=None):
    #以二进制格式打开图片
    with tf.gfile.GFile(image_path, 'rb') as fid:
    encode_jpg = fid.read()
    encode_jpg_io = io.BytesIO(encode_jpg)
    image = Image.open(encode_jpg_io)
    # process png pic with four channels,将图片转为RGB
    image, process_flag = process_image_channels(image)
    # reshape image
    image = process_image_reshape(image, resize)
    if process_flag == True or resize is not None:
    bytes_io = io.BytesIO()
    image.save(bytes_io, format='JPEG')
    encoded_jpg = bytes_io.getvalue()
    width, height = image.size
    tf_example = tf.train.Example(
    features=tf.train.Features(
    feature={
    'image/encoded': bytes_feature(encode_jpg),
    'image/format': bytes_feature(b'jpg'),
    'image/class/label': int64_feature(label),
    'image/height': int64_feature(height),
    'image/width': int64_feature(width)
    }
    ))
    return tf_example


    def generate_tfrecord(annotation_dict, record_path, resize=None):
    num_tf_example = 0
    #writer就是我们TFrecord生成器
    writer = tf.python_io.TFRecordWriter(record_path)
    for image_path, label in annotation_dict.items():
    #tf.gfile.GFile获取文本操作句柄,类似于python提供的文本操作open()函数
    #filename是要打开的文件名,mode是以何种方式去读写,将会返回一个文本操作句柄。
    if not tf.gfile.GFile(image_path):
    print("{} does not exist".format(image_path))
    tf_example = create_tf_example(image_path, label, resize)
    #tf_example.SerializeToString()是将Example中的map压缩为二进制文件
    writer.write(tf_example.SerializeToString())
    num_tf_example += 1
    if num_tf_example % 100 == 0:
    print("Create %d TF_Example" % num_tf_example)
    writer.close()
    print("{} tf_examples has been created successfully, which are saved in {}".format(num_tf_example, record_path))


    def main(_):
    word2number_dict = {
    "combinations": 0,
    "details": 1,
    "sizes": 2,
    "tags": 3,
    "models": 4,
    "tileds": 5,
    "hangs": 6
    }
    # 图片路径
    images_dir = FLAGS.images_dir
    #annotation_path = FLAGS.annotation_path
    #生成TFRecord文件的路径
    record_path = FLAGS.record_path
    annotation_dict = get_annotation_dict(images_dir, word2number_dict)
    generate_tfrecord(annotation_dict, record_path)


    if __name__ == '__main__':
    tf.app.run()

     总结:1.制作数据(图片路径和标签)

        2.利用tf.python_io.TFRecordWriter创建一个writer,就是我们TFrecord生成器

        3.遍历数据集,以二进制形式打开图片

        4.利用tf.train.Example将图片,图片格式,标签和长宽进行保存

        5然后利用writer.write(tf_example.SerializeToString())将tf.train.Example存储的数据格式写入TFRecord即可



    参考链接:https://www.jianshu.com/p/b480e5fcb638
  • 相关阅读:
    node 跨域请求设置
    iOS下如何阻止橡皮筋效果
    您的手机上未安装应用程序 android 点击快捷方式提示未安装程序的解决
    您的手机上未安装应用程序 android 点击快捷方式提示未安装程序的解决
    ImageButton的坑 ImageButton 有问题
    ImageButton的坑 ImageButton 有问题
    ImageButton的坑 ImageButton 有问题
    textView代码设置文字居中失效 textView设置文字居中两种方法
    textView代码设置文字居中失效 textView设置文字居中两种方法
    textView代码设置文字居中失效 textView设置文字居中两种方法
  • 原文地址:https://www.cnblogs.com/lzq116/p/12029421.html
Copyright © 2011-2022 走看看