zoukankan      html  css  js  c++  java
  • tensorflow制作tfrecord格式数据

    tf.Example msg

    TensorFlow 提供了一种统一的文件格式 TFRecord 来存储数据(常用于存储图像数据)。它基于 Google 自家的 protobuf:把每条数据序列化成 protobuf 消息(tf.train.Example)的二进制记录,再按顺序写入文件.

    To read data efficiently it can be helpful to serialize your data and store it in a set of files (100-200MB each) that can each be read linearly. This is especially true if the data is being streamed over a network. This can also be useful for caching any data-preprocessing.

    The TFRecord format is a simple format for storing a sequence of binary records.
    protobuf消息的格式如下:
    https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/core/example/feature.proto

    message BytesList {
      repeated bytes value = 1;
    }
    message FloatList {
      repeated float value = 1 [packed = true];
    }
    message Int64List {
      repeated int64 value = 1 [packed = true];
    }
    
    // Containers for non-sequential data.
    message Feature {
      // Each feature can be exactly one kind.
      oneof kind {
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
      }
    };
    
    message Features {
      map<string, Feature> feature = 1;
    };
    
    message FeatureList {
      repeated Feature feature = 1;
    };
    
    message FeatureLists {
      map<string, FeatureList> feature_list = 1;
    };
    

    tf.train.Example 消息内部包装了一个 Features,本质上就是一个 map,格式为 {"string": tf.train.Feature}.
    tf.train.Feature 只能是以下 3 种类型之一(对应 proto 中的 oneof kind):

    • tf.train.BytesList
      • string
      • byte
    • tf.train.FloatList
      • float(float32)
      • double(float64)
    • tf.train.Int64List
      • bool
      • enum
      • int32
      • uint32
      • int64
      • uint64

    参考tensorflow官方文档

    将自己的数据制作为tfrecord格式

    完整代码

    from __future__ import absolute_import, division, print_function, unicode_literals
    import tensorflow as tf
    import numpy as np
    import IPython.display as display
    import os
    import cv2 as cv
    import argparse
    
    def _bytes_feature(value):
        """Return a ``tf.train.Feature`` holding ``value`` in a one-element BytesList.

        ``value`` is expected to be ``bytes`` (or a string). A TensorFlow tensor
        is first converted with ``.numpy()``, because ``BytesList`` won't unpack
        a string from an EagerTensor.
        """
        # Check against tf.Tensor (the base class of EagerTensor) rather than
        # type(tf.constant(0)): that expression creates a fresh constant on every
        # call, which in graph mode adds a new node to the default graph each time.
        if isinstance(value, tf.Tensor):
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    def _float_feature(value):
        """Wrap a single float/double scalar in a ``tf.train.Feature`` (float_list)."""
        float_list = tf.train.FloatList(value=[value])
        return tf.train.Feature(float_list=float_list)
    
    def _int64_feature(value):
        """Wrap a single bool/enum/int/uint scalar in a ``tf.train.Feature`` (int64_list)."""
        int64_list = tf.train.Int64List(value=[value])
        return tf.train.Feature(int64_list=int64_list)
    
    def convert_to_tfexample(img_data, label, height=320, width=320, image_format='png'):
        """Wrap one encoded image and its class label in a ``tf.train.Example``.

        Args:
            img_data: the encoded image bytes (e.g. the raw contents of a PNG
                file), stored unchanged under ``image/encoded``.
            label: integer class index, stored under ``image/class/label``.
            height, width: image size stored as metadata. They are NOT derived
                from ``img_data``; the 320 defaults are written as-is, so pass
                the real size when images are not 320x320.
            image_format: name of the encoding stored under ``image/format``
                (defaults to ``'png'``, the previously hard-coded value).

        Returns:
            tf.train.Example with the features ``image/encoded``,
            ``image/format``, ``image/class/label``, ``image/height`` and
            ``image/width``.
        """
        example = tf.train.Example(features=tf.train.Features(feature={
            'image/encoded': _bytes_feature(img_data),
            'image/format': _bytes_feature(tf.compat.as_bytes(image_format)),
            'image/class/label': _int64_feature(label),
            'image/height': _int64_feature(height),
            'image/width': _int64_feature(width),
        }))
        return example
    
    #path="/home/sc/disk/data/lishui/1"
    def read_dataset(path):
        """Collect PNG images and their class labels from a directory tree.

        Every ``*.png`` file found under ``path`` (recursively, via ``os.walk``)
        is paired with the file of the same name with a ``.txt`` extension in
        the same directory, e.g. ``a.png`` / ``a.txt``. The class index is the
        first whitespace-separated field on the label file's first line.
        Images without a matching label file are skipped.

        Args:
            path: root directory to search.

        Returns:
            (imgs, labels): ``imgs`` is a list of raw (still PNG-encoded) image
            bytes and ``labels`` a list of ints; ``labels[i]`` belongs to
            ``imgs[i]``.

        Raises:
            ValueError: if a label file is empty or its first field is not an
                integer.
        """
        imgs = []
        labels = []
        for root, dirs, files in os.walk(path):
            # Visit directories and files in a stable order so the output (and
            # thus the written TFRecord) is reproducible across runs.
            dirs.sort()
            for name in sorted(files):
                stem, ext = os.path.splitext(name)
                if ext != '.png':
                    continue
                # Join with ``root`` (the directory actually being walked), not
                # ``path``, so images in sub-directories resolve correctly.
                img_file = os.path.join(root, name)
                # Swap only the extension; str.replace('png', 'txt') would also
                # rewrite any 'png' appearing elsewhere in the path.
                label_file = os.path.join(root, stem + '.txt')
                if not os.path.isfile(label_file):
                    continue

                with open(label_file) as f:
                    fields = f.readline().split()
                if not fields:
                    raise ValueError('empty label file: {}'.format(label_file))
                class_index = int(fields[0])

                # os.walk only covers the local filesystem, so the builtin open
                # is sufficient (tf.gfile no longer exists in TF2).
                with open(img_file, 'rb') as f:
                    img = f.read()

                # Append both only after both reads succeeded, keeping the two
                # lists aligned.
                labels.append(class_index)
                imgs.append(img)

        return imgs, labels
    
    def arg_parse(argv=None):
        """Parse this script's command-line options.

        Args:
            argv: list of argument strings to parse. ``None`` (the default)
                means ``sys.argv[1:]``, exactly as ``argparse`` does.

        Returns:
            argparse.Namespace with ``dir`` (directory holding the ``*.png`` /
            ``*.txt`` pairs) and ``output`` (path of the TFRecord file to write).
        """
        parser = argparse.ArgumentParser(
            description='Pack PNG images and their .txt class labels into a TFRecord file.',
            epilog='example: python create_tfrecord.py -d /home/sc/disk/data/lishui/1 -o train.tfrecord')
        # The options used to be declared required='True', which made these
        # defaults unreachable. Without `required` the defaults apply when an
        # option is omitted, and every previously valid command line still
        # parses to the same values.
        parser.add_argument('-d', '--dir', type=str, default='./data',
                            help='dir store images/label file (default: %(default)s)')
        parser.add_argument('-o', '--output', type=str, default='./outdata.tfrecord',
                            help='output tfrecord file name (default: %(default)s)')
        return parser.parse_args(argv)
    
    def main():
        """Command-line entry point: write every image/label pair under --dir to --output as a TFRecord."""
        args = arg_parse()

        # Read the data before opening the writer, so a failure while reading
        # does not leave a truncated/empty output file behind.
        imgs, labels = read_dataset(args.dir)
        # The with-block closes (and flushes) the writer even if serialization
        # raises part-way through, instead of leaking the file handle.
        with tf.io.TFRecordWriter(args.output) as writer:
            for img, label in zip(imgs, labels):
                example = convert_to_tfexample(img, label)
                writer.write(example.SerializeToString())

        print("write done")
    
    if __name__ == '__main__':
        # usage: python create_tfrecord.py [-d DATA_DIR] [-o OUTPUT_TFRECORD]
        # ex:    python create_tfrecord.py -d /home/sc/disk/data/lishui/1 -o train.tfrecord
        main()
    

    首先就是需要有工具函数把byte/string/float/int..等等类型的数据转换为tf.train.Feature

    def _bytes_feature(value):
        """Return a ``tf.train.Feature`` holding ``value`` in a one-element BytesList.

        ``value`` is expected to be ``bytes`` (or a string). A TensorFlow tensor
        is first converted with ``.numpy()``, because ``BytesList`` won't unpack
        a string from an EagerTensor.
        """
        # Check against tf.Tensor (the base class of EagerTensor) rather than
        # type(tf.constant(0)): that expression creates a fresh constant on every
        # call, which in graph mode adds a new node to the default graph each time.
        if isinstance(value, tf.Tensor):
            value = value.numpy()
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    def _float_feature(value):
        """Wrap a single float/double scalar in a ``tf.train.Feature`` (float_list)."""
        float_list = tf.train.FloatList(value=[value])
        return tf.train.Feature(float_list=float_list)
    
    def _int64_feature(value):
        """Wrap a single bool/enum/int/uint scalar in a ``tf.train.Feature`` (int64_list)."""
        int64_list = tf.train.Int64List(value=[value])
        return tf.train.Feature(int64_list=int64_list)
    

    接下来,对于图片矩阵和标签数据,我们调用上述工具函数,将单幅图片及其标签信息转换为 tf.train.Example 消息.

    def convert_to_tfexample(img, label):
        """Wrap one image and its class label in a ``tf.train.Example``.

        Args:
            img: either already-serialized ``bytes`` (e.g. the raw contents of a
                PNG file, as ``read_dataset`` returns), stored unchanged, or an
                array-like object with ``tobytes()`` such as a numpy image
                matrix, whose raw buffer is stored. In the array case only the
                buffer is stored, not its shape or dtype, so the reader must
                know those.
            label: integer class index.

        Returns:
            tf.train.Example with an int64 feature ``'label'`` and a bytes
            feature ``'img'``.
        """
        # ndarray.tostring() is deprecated (NumPy 1.19) and removed in NumPy 2.0;
        # tobytes() is the same operation. bytes has neither method, so it is
        # passed through as-is.
        img_raw = img if isinstance(img, bytes) else img.tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': _int64_feature(label),
            'img': _bytes_feature(img_raw)}))
        return example
    

    对于我的数据,图片以及label文件位于同一目录.比如dir下有图片a.png及相应的标签信息a.txt.

    def read_dataset(path):
        """Collect PNG images and their class labels from a directory tree.

        Every ``*.png`` file found under ``path`` (recursively, via ``os.walk``)
        is paired with the file of the same name with a ``.txt`` extension in
        the same directory, e.g. ``a.png`` / ``a.txt``. The class index is the
        first whitespace-separated field on the label file's first line.
        Images without a matching label file are skipped.

        Args:
            path: root directory to search.

        Returns:
            (imgs, labels): ``imgs`` is a list of raw (still PNG-encoded) image
            bytes and ``labels`` a list of ints; ``labels[i]`` belongs to
            ``imgs[i]``.

        Raises:
            ValueError: if a label file is empty or its first field is not an
                integer.
        """
        imgs = []
        labels = []
        for root, dirs, files in os.walk(path):
            # Visit directories and files in a stable order so the output (and
            # thus the written TFRecord) is reproducible across runs.
            dirs.sort()
            for name in sorted(files):
                stem, ext = os.path.splitext(name)
                if ext != '.png':
                    continue
                # Join with ``root`` (the directory actually being walked), not
                # ``path``, so images in sub-directories resolve correctly.
                img_file = os.path.join(root, name)
                # Swap only the extension; str.replace('png', 'txt') would also
                # rewrite any 'png' appearing elsewhere in the path.
                label_file = os.path.join(root, stem + '.txt')
                if not os.path.isfile(label_file):
                    continue

                with open(label_file) as f:
                    fields = f.readline().split()
                if not fields:
                    raise ValueError('empty label file: {}'.format(label_file))
                class_index = int(fields[0])

                # os.walk only covers the local filesystem, so the builtin open
                # is sufficient (tf.gfile no longer exists in TF2).
                with open(img_file, 'rb') as f:
                    img = f.read()

                # Append both only after both reads succeeded, keeping the two
                # lists aligned.
                labels.append(class_index)
                imgs.append(img)

        return imgs, labels
    

    遍历data目录,完成图片读取,及label读取. 如果你的数据不是这么存放的,就修改这个函数好了,返回值仍然是imgs,labels

    最后就是调用 tf.io.TFRecordWriter将每一个tf.train.Example消息写入文件保存.

    def main():
        """Command-line entry point: write every image/label pair under --dir to --output as a TFRecord."""
        args = arg_parse()

        # Read the data before opening the writer, so a failure while reading
        # does not leave a truncated/empty output file behind.
        imgs, labels = read_dataset(args.dir)
        # The with-block closes (and flushes) the writer even if serialization
        # raises part-way through, instead of leaking the file handle.
        with tf.io.TFRecordWriter(args.output) as writer:
            for img, label in zip(imgs, labels):
                example = convert_to_tfexample(img, label)
                writer.write(example.SerializeToString())

        print("write done")
    
  • 相关阅读:
    心得体悟帖---200130(专业长才(敲门砖))(希望)
    心得体悟帖---200130(一举多的)(少了发自内心的从容)
    范仁义css3课程---19、流动布局
    范仁义css3课程---18、overflow
    日常英语---200130(inspire)
    日常英语---200130(Basketball fans around the world are mourning the death of American superstar Kobe Bryant.)
    视频中的ts文件是什么
    如何美化windows桌面
    心得体悟帖---200127(囚笼-它会推着我的,不必多想)(过好当下,享受当下)
    心得体悟帖---有哪些越早知道越好的人生经验?(转自知乎)
  • 原文地址:https://www.cnblogs.com/sdu20112013/p/11820140.html
Copyright © 2011-2022 走看看