zoukankan      html  css  js  c++  java
  • 3. Tensorflow生成TFRecord

    1. Tensorflow高效流水线Pipeline

    2. Tensorflow的数据处理中的Dataset和Iterator

    3. Tensorflow生成TFRecord

    4. Tensorflow的Estimator实践原理

    1. 前言

    TFRecord是TensorFlow官方推荐使用的数据格式化存储工具,它不仅规范了数据的读写方式,还大大地提高了IO效率。

    2. TFRecord原理步骤

    TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,只要生成一次TFRecord,之后的数据读取和加工处理的效率都会得到提高。

    而且,使用TFRecord可以直接作为Cloud ML Engine的输入数据。

    一般来说,我们使用TensorFlow进行数据读取的方式有以下4种:

    1. 预先把所有数据加载进内存
    2. 在每轮训练中使用原生Python代码读取一部分数据,然后使用feed_dict输入到计算图
    3. 利用Threading和Queues从TFRecord中分批次读取数据
    4. 使用Dataset API

    (1)方案对于数据量不大的场景来说是足够简单而高效的,但是随着数据量的增长,势必会对有限的内存空间带来极大的压力,还有长时间的数据预加载,甚至导致我们十分熟悉的OutOfMemoryError;

    (2)方案可以一定程度上缓解了方案(1)的内存压力问题,但是由于在单线程环境下我们的IO操作一般都是同步阻塞的,势必会在一定程度上导致学习时间的增加,尤其是相同的数据需要重复多次读取的情况下;

    而方案(3)和方案(4)都利用了我们的TFRecord,由于使用了多线程使得IO操作不再阻塞我们的模型训练,同时为了实现线程间的数据传输引入了Queues。

    2.1 生成TFRecord数据

    example = tf.train.Example()这句将数据赋给了变量example(可以看到里面是通过字典结构实现的赋值),然后用writer.write(example.SerializeToString()) 这句实现写入。

    tfrecord_filename = './tfrecord/train.tfrecord'
    # 创建.tfrecord文件,准备写入
    writer = tf.python_io.TFRecordWriter(tfrecord_filename)
    for i in range(100):
        img_raw = np.random.random_integers(0,255,size=(7,30)) # 创建7*30,取值在0-255之间随机数组
        img_raw = img_raw.tostring()
        example = tf.train.Example(features=tf.train.Features(
                feature={
                # Int64List储存int数据
                'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])), 
                # 储存byte二进制数据
                'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
                }))
        # 序列化过程
        writer.write(example.SerializeToString()) 
    writer.close()
    

    值得注意的是赋值给example的数据格式。从前面tf.train.Example的定义可知,tfrecord支持整型、浮点数和二进制三种格式,分别如下:

    tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))
    tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))
    tf.train.Feature(float_list = tf.train.FloatList(value=[float_scalar]))
    

    例如图片等数组形式(array)的数据,可以保存为numpy array的格式,转换为string,然后保存到二进制格式的feature中。对于单个的数值(scalar),可以直接赋值。这里value=[×]的[]非常重要,也就是说输入的必须是列表(list)。当然,对于输入数据是向量形式的,可以根据数据类型(float还是int)分别保存。并且在保存的时候还可以指定数据的维数。

    2.2 读取TFRecord数据

    从TFRecord文件中读取数据, 首先需要用tf.train.string_input_producer生成一个解析队列。之后调用tf.TFRecordReader的tf.parse_single_example解析器。如下图:

    image

    具体代码如下:

    def read_tfrecord(filename):
        filename_queue = tf.train.string_input_producer([filename])
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
    
        features = tf.parse_single_example(
            serialized_example,
            features={
                'sentence': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64)
            })
    
        sentence, label = tf.train.batch([features['sentence'], features['label']],
                batch_size=16,
                capacity=64)
    
        return sentence, label
    

    3. 总结

    TFRecord的生成效率可能不是很快(可以使用多进程),但是一旦TFRecord数据处理好了,对以后每次的读取,解析都有速度上的提升。而且TFRecord也可以和Tensorflow自带的数据处理方式Dataset搭配使用,基本可以解决大数据量的训练操作。

  • 相关阅读:
    MySQL架构备份
    MySQL物理备份 xtrabackup
    MySQL物理备份 lvm-snapshot
    MySQL逻辑备份mysqldump
    MySQL逻辑备份into outfile
    MySQ数据备份
    前端基础-- HTML
    奇淫异巧之 PHP 后门
    php中代码执行&&命令执行函数
    windows进程中的内存结构(缓冲溢出原理)
  • 原文地址:https://www.cnblogs.com/huangyc/p/10339831.html
Copyright © 2011-2022 走看看