zoukankan      html  css  js  c++  java
  • TensorFlowIO操作(二)----读取数据

    读取数据

    小数量数据读取

    这仅用于可以完全加载到存储器中的小的数据集有两种方法:

    • 存储在常数中。
    • 存储在变量中,初始化后,永远不要改变它的值。

    使用常数更简单一些,但是会使用更多的内存,因为常数会内联的存储在数据流图数据结构中,这个结构体可能会被复制几次。

    training_data = ...
    training_labels = ...
    with tf.Session():
      input_data = tf.constant(training_data)
      input_labels = tf.constant(training_labels)

    要改为使用变量的方式,您就需要在数据流图建立后初始化这个变量。

    training_data = ...
    training_labels = ...
    with tf.Session() as sess:
      data_initializer = tf.placeholder(dtype=training_data.dtype,
                                        shape=training_data.shape)
      label_initializer = tf.placeholder(dtype=training_labels.dtype,
                                         shape=training_labels.shape)
      input_data = tf.Variable(data_initalizer, trainable=False, collections=[])
      input_labels = tf.Variable(label_initalizer, trainable=False, collections=[])
      ...
      sess.run(input_data.initializer,
               feed_dict={data_initializer: training_data})
      sess.run(input_labels.initializer,
               feed_dict={label_initializer: training_lables})

    设定trainable=False可以防止该变量被数据流图的GraphKeys.TRAINABLE_VARIABLES收集,这样我们就不会在训练的时候尝试更新它的值;设定collections=[]可以防止GraphKeys.VARIABLES收集后做为保存和恢复的中断点。设定这些标志,是为了减少额外的开销

    文件读取

    先看下文件读取以及读取数据处理成张量结果的过程:

    一般数据文件格式有文本、excel和图片数据。那么TensorFlow都有对应的解析函数,除了这几种。还有TensorFlow指定的文件格式。

    标准TensorFlow格式

    TensorFlow还提供了一种内置文件格式TFRecord,二进制数据和训练类别标签数据存储在同一文件。模型训练前图像等文本信息转换为TFRecord格式。TFRecord文件是protobuf格式。数据不压缩,可快速加载到内存。TFRecords文件包含 tf.train.Example protobuf,需要将Example填充到协议缓冲区,将协议缓冲区序列化为字符串,然后使用该文件将该字符串写入TFRecords文件。在图像操作我们会介绍整个过程以及详细参数。

    数据读取实现

    文件队列生成函数

    • tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, name=None)

    产生指定文件张量

    文件阅读器类

    • class tf.TextLineReader

    阅读文本文件逗号分隔值(CSV)格式

    • tf.FixedLengthRecordReader

    要读取每个记录是固定数量字节的二进制文件

    • tf.TFRecordReader

    读取TfRecords文件

    解码

    由于从文件中读取的是字符串,需要函数去解析这些字符串到张量

    • tf.decode_csv(records,record_defaults,field_delim = None,name = None)将CSV转换为张量,与tf.TextLineReader搭配使用

    • tf.decode_raw(bytes,out_type,little_endian = None,name = None) 将字节转换为一个数字向量表示,字节为一字符串类型的张量,与函数tf.FixedLengthRecordReader搭配使用

    生成文件队列

    将文件名列表交给tf.train.string_input_producer函数。string_input_producer来生成一个先入先出的队列,文件阅读器会需要它们来取数据。string_input_producer提供的可配置参数来设置文件名乱序和最大的训练迭代数,QueueRunner会为每次迭代(epoch)将所有的文件名加入文件名队列中,如果shuffle=True的话,会对文件名进行乱序处理。一过程是比较均匀的,因此它可以产生均衡的文件名队列。

    这个QueueRunner工作线程是独立于文件阅读器的线程,因此乱序和将文件名推入到文件名队列这些过程不会阻塞文件阅读器运行。根据你的文件格式,选择对应的文件阅读器,然后将文件名队列提供给阅读器的 read 方法。阅读器的read方法会输出一个键来表征输入的文件和其中纪录(对于调试非常有用),同时得到一个字符串标量,这个字符串标量可以被一个或多个解析器,或者转换操作将其解码为张量并且构造成为样本。

    # 读取CSV格式文件
    # 1、构建文件队列
    
    # 2、构建读取器,读取内容
    
    # 3、解码内容
    
    # 4、现读取一个内容,如果有需要,就批处理内容
    import tensorflow as tf
    import os
    def readcsv_decode(filelist):
        """
        读取并解析文件内容
        :param filelist: 文件列表
        :return: None
        """
    
        # 把文件目录和文件名合并
        flist = [os.path.join("./csvdata/",file) for file in filelist]
    
        # 构建文件队列
        file_queue = tf.train.string_input_producer(flist,shuffle=False)
    
        # 构建阅读器,读取文件内容
        reader = tf.TextLineReader()
    
        key,value = reader.read(file_queue)
    
        record_defaults = [["null"],["null"]] # [[0],[0],[0],[0]]
    
        # 解码内容,按行解析,返回的是每行的列数据
        example,label = tf.decode_csv(value,record_defaults=record_defaults)
    
        # 通过tf.train.batch来批处理数据
        example_batch,label_batch = tf.train.batch([example,label],batch_size=9,num_threads=1,capacity=9)
    
    
        with tf.Session() as sess:
    
            # 线程协调员
            coord = tf.train.Coordinator()
    
            # 启动工作线程
            threads = tf.train.start_queue_runners(sess,coord=coord)
    
            # 这种方法不可取
            # for i in range(9):
            #     print(sess.run([example,label]))
    
            # 打印批处理的数据
            print(sess.run([example_batch,label_batch]))
    
    
            coord.request_stop()
    
            coord.join(threads)
    
        return None
    
    
    if __name__=="__main__":
        filename_list = os.listdir("./csvdata")
        readcsv_decode(filename_list)

    每次read的执行都会从文件中读取一行内容,注意,(这与后面的图片和TfRecords读取不一样),decode_csv操作会解析这一行内容并将其转为张量列表。如果输入的参数有缺失,record_default参数可以根据张量的类型来设置默认值。在调用run或者eval去执行read之前,你必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。

  • 相关阅读:
    9.3 寻找magic index
    编写一个函数,确定需要改变几个位,才能将整数A转成整数B。
    打印0-1之间double数字的二进制表示
    打印二叉树节点数值总和等于某个给定节点的所有路径
    判断T2是否是T1的子树
    二棵树某两个节点的公共祖先。
    4.6 找出二叉树中指定节点的下一个节点(中序后继),假定每个节点有父指针。
    队列实现max操作,要求尽量提高效率。
    用两个stack设计一个队列
    转分享一个MAC下绕开百度网盘限速下载的方法,三步操作永久生效
  • 原文地址:https://www.cnblogs.com/fwl8888/p/9794452.html
Copyright © 2011-2022 走看看