zoukankan      html  css  js  c++  java
  • TFRecord文件

    对于数据进行统一的管理是很有必要的.TFRecord就是对于输入数据做统一管理的格式.加上一些多线程的处理方式,使得在训练期间对于数据管理把控的效率和舒适度都好于暴力的方法.
    小的任务什么方法差别不大,但是对于大的任务,使用统一格式管理的好处就非常显著了.因此,TFRecord的使用方法很有必要熟悉.

    一.生成TFrecords

    tf.python_io.TFRecordWriter 类
    把记录写入到TFRecords文件的类.

    __init__(path,options=None)

    作用:创建一个TFRecordWriter对象,这个对象就负责写记录到指定的文件中去了.
    参数:
    path: TFRecords 文件路径
    options: (可选) TFRecordOptions对象

    close()

    作用:关闭对象.

    write(record)

    作用:把字符串形式的记录写到文件中去.
    参数:
    record: 字符串,待写入的记录

    Ⅱ.tf.train.Example
    这个类是非常重要的,TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的.

    函数:
    __init__(**kwargs)

    这个函数是初始化函数,会生成一个Example对象,一般我们使用的时候,是传入一个tf.train.Features对象进去.

    SerializeToString()

    作用:把example序列化为一个字符串,因为在写入到TFRcorde的时候,write方法的参数是字符串的.

    .tf.train.Features

    函数:
    __init__(**kwargs)
    作用:初始化Features对象,一般我们是传入一个字典,字典的键是一个字符串,表示名字,字典的值是一个tf.train.Feature对象.

    Ⅳ.tf.train.Feature
    class tf.train.Feature

    属性:
    bytes_list
    float_list
    int64_list

    函数:
    __init__(**kwargs)

    作用:构造一个Feature对象,一般使用的时候,传入 tf.train.Int64List, tf.train.BytesList, tf.train.FloatList对象.

    Ⅴ.tf.train.Int64List, tf.train.BytesList, tf.train.FloatList
    使用的时候,一般传入一个具体的值,比如学习任务中的标签就可以传进value=tf.train.Int64List,而图片就可以先转为字符串的格式之后,传入value=tf.train.BytesList中.

    存入TFRecords文件需要数据先存入名为example的protocol buffer,然后将其serialize成为string才能写入。example中包含features,用于描述数据类型:bytes,float,int64。

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
    
    train_filename = 'train.tfrecords'
    with tf.python_io.TFRecordWriter(train_filename) as tfrecord_writer:  
        for i in range(len(images)):
            # read in image data by tf
            img_data = tf.gfile.FastGFile(images[i], 'rb').read()  # image data type is string
            label = labels[i]
            # get width and height of image
            image_shape = cv2.imread(images[i]).shape
            width = image_shape[1]
            height = image_shape[0]
            # create features
            feature = {'train/image': _bytes_feature(img_data),
                               'train/label': _int64_feature(label),  # label: integer from 0-N
                               'train/height': _int64_feature(height), 
                               'train/width': _int64_feature(width)}
            # create example protocol buffer
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            # serialize protocol buffer to string
            tfrecord_writer.write(example.SerializeToString())
     tfrecord_writer.close()
    img_raw = img.tobytes()#将图片转化为二进制格式
    # 为图像建Example
    example = tf.train.Example(features=tf.train.Features(feature={
    'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    }))
    # 写入tfrecord文件
    num_pic += 1
    writer.write(example.SerializeToString())



    二  Tensorflow读写TFRecords文件

    在使用slim之类的tensorflow自带框架的时候一般默认的数据格式就是TFRecords,在训练的时候使用TFRecords中数据的流程如下

    使用input pipeline读取tfrecords文件/其他支持的格式,然后随机乱序,生成文件序列,读取并解码数据,输入模型训练。

    如果有一串jpg图片地址和相应的标签:imageslabels

    首先用tf.train.string_input_producer读取tfrecords文件的list建立FIFO序列,可以申明num_epoches和shuffle参数表示需要读取数据的次数以及时候将tfrecords文件读入顺序打乱,然后定义TFRecordReader读取上面的序列返回下一个record,用tf.parse_single_example对读取到TFRecords文件进行解码,根据保存的serialize example和feature字典返回feature所对应的值。此时获得的值都是string,需要进一步解码为所需的数据类型。把图像数据的string reshape成原始图像后可以进行preprocessing操作。此外,还可以通过tf.train.batch或者tf.train.shuffle_batch将图像生成batch序列。

    由于tf.train函数会在graph中增加tf.train.QueueRunner类,而这些类有一系列的enqueue选项使一个队列在一个线程里运行。为了填充队列就需要用tf.train.start_queue_runners来为所有graph中的queue runner启动线程,而为了管理这些线程就需要一个tf.train.Coordinator来在合适的时候终止这些线程。

     

     

    import tensorflow as tf
    import matplotlib.pyplot as plt
    
    data_path = 'train.tfrecords'
    
    with tf.Session() as sess:
        # feature key and its data type for data restored in tfrecords file
        feature = {'train/image': tf.FixedLenFeature([], tf.string),
                         'train/label': tf.FixedLenFeature([], tf.int64),
                         'train/height': tf.FixedLenFeature([], tf.int64),
                         'train/width': tf.FixedLenFeature([], tf.int64)}
        # define a queue base on input filenames
        filename_queue = tf.train.string_input_producer([data_path], num_epoches=1)
        # define a tfrecords file reader
        reader = tf.TFRecordReader()
        # read in serialized example data
        _, serialized_example = reader.read(filename_queue)
        # decode example by feature
        features = tf.parse_single_example(serialized_example, features=feature)
        image = tf.image.decode_jpeg(features['train/image'])
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)  # convert dtype from unit8 to float32 for later resize
        label = tf.cast(features['train/label'], tf.int64)
        height = tf.cast(features['train/height'], tf.int32)
        width = tf.cast(features['train/width'], tf.int32)
        # restore image to [height, width, 3]
        image = tf.reshape(image, [height, width, 3])
        # resize
        image = tf.image.resize_images(image, [224, 224])
        # create bathch
        images, labels = tf.train.shuffle_batch([image, label], batch_size=10, capacity=30, num_threads=1, min_after_dequeue=10) # capacity是队列的最大容量,
    #min_after_dequeue是dequeue后最小的队列大小,num_threads是进行队列操作的线程数。
    # initialize global & local variables init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) sess.run(init_op) # create a coordinate and run queue runner objects coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for batch_index in range(3): batch_images, batch_labels = sess.run([images, labels]) for i in range(10): plt.imshow(batch_images[i, ...]) plt.show() print "Current image label is: ", batch_lables[i] # close threads coord.request_stop() coord.join(threads) sess.close()

    tf.decode_raw与tf.cast的区别

    tf.decode_raw函数的意思是将原来编码为字符串类型的变量重新变回来,这个方法在数据集dataset中很常用,

    因为制作图片源数据一般写进tfrecord里用to_bytes的形式,也就是字符串。这里将原始数据取出来 必须制定原始数据的格式,原始数据是什么格式这里解析必须是什么格式,要不然会出现形状的不对应问题!

    例如原始数据是tf.float64然后to_bytes,但是用tf.decode_raw解析的时候使用了tf.float32,那么形状跟值都会跟原始数据有差别,后面传入网络的时候一定会报

    tensorflow : Input to reshape is a tensor with 16384 values, but the requested shape has 49152 这种错误

    tf.cast

    这个函数主要用于数据类型的转变,不会改变原始数据的值还有形状的,

    retyped_images = tf.cast(decoded_images, tf.float32)

    labels = tf.cast(features['label'],tf.int32)

    这里retyped_images原来是tf.float64形状 labels是tf.uint8。tf.cast还可以用于将numpy数组转化为tensor

    tf.decode_raw()解析固定长度的数据,对于数据格式有一定的要求,应为tf.uint8

    import tensorflow as tf import cv2 def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) train_filename = 'train.tfrecords'with tf.python_io.TFRecordWriter(train_filename) as tfrecord_writer: for i in range(len(images)): # read in image data by tf img_data = tf.gfile.FastGFile(images[i], 'rb').read() # image data type is string label = labels[i] # get width and height of image image_shape = cv2.imread(images[i]).shape width = image_shape[1] height = image_shape[0] # create features feature = {'train/image': _bytes_feature(img_data), 'train/label': _int64_feature(label), # label: integer from 0-N'train/height': _int64_feature(height), 'train/width': _int64_feature(width)} # create example protocol buffer example = tf.train.Example(features=tf.train.Features(feature=feature)) # serialize protocol buffer to string tfrecord_writer.write(example.SerializeToString()) tfrecord_writer.close()

  • 相关阅读:
    Android网站
    vim里面搜索字符串
    ssd遇到的bug
    ssd训练自己的数据集
    slover层解读
    caffe LOG LOG_IF
    cuda输出
    css中合理的使用nth-child实现布局
    Linux VM环境配置
    怎样对Android设备进行网络抓包
  • 原文地址:https://www.cnblogs.com/tingtin/p/12539031.html
Copyright © 2011-2022 走看看