zoukankan      html  css  js  c++  java
  • Tensorflow踩坑系列---数据读取文件队列

    一:总括文件读取方式

    1.供给数据(Feeding): 由占位符placeholder代替数据,运行时使用feed_dict填入数据

    2.预加载数据: 数据直接嵌入graph,由graph传入session中运行

     

    3.从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据,这就是这篇文将要讲的内容。

    前两种方法很方便,但是遇到大型数据的时候就会很吃力,即使是Feeding,中间环节的增加也是不小的开销,比如数据类型转换等等。 最优的方案就是在Graph定义好文件读取的方法,让TF自己去从文件中读取数据,并解码成可使用的样本集。

    对于大的数据集很难用numpy数组保存,所以这里介绍一下Tensorflow读取很大数据集的方法:string_input_producer()和slice_input_producer()。

    这种直接从文件中读取数据的方式需要设计成Queue的方式才能较好的解决IO瓶颈的问题。
    Queue机制有如下三个特点:

    (1)producer-consumer pattern(生产消费模式)
    (2)独立于主线程执行
    (3)异步IO: reader.read(queue) tf.train.batch()

    一:string_input_producer队列使用(单个Reader、单文件读取)

    import tensorflow as tf
    IMAGE_DIR = "./Images/SourceImgs/"
    QUEUE_DIR = "./Images/QueueImgs/"
    FILELIST = ["1100.jpg","1101.jpg","1102.jpg","1104.jpg","1105.jpg",
               "1110.jpg","1114.jpg","1115.jpg","1116.jpg","1118.jpg"]

    (一)获取文件列表

    def getFileList(rootDir=IMAGE_DIR,files=FILELIST):
        fsl = []
        for fn in files:
            fsl.append(IMAGE_DIR+fn)
        return fsl

    (二)使用队列读取文件

    with tf.Session() as sess:
        files_list = getFileList()
        #string_input_producer产生文件名队列
        filename_queue = tf.train.string_input_producer(files_list,shuffle=True,num_epochs=3)
        #reader从文件名队列中读取数据
        reader = tf.WholeFileReader()
        key,value = reader.read(filename_queue) #返回文件名和文件内容
        
        sess.run(tf.local_variables_initializer()) #初始化上面的局部变量
        
        #启动start_queue_runners之后,才会开始填充队列
        threads = tf.train.start_queue_runners(sess=sess)
        i = 1
        while True:
            try:
                image_data = sess.run(value)
                with open(QUEUE_DIR+"%d.jpg"%i,"wb") as f:
                    f.write(image_data)
                i+=1
            except BaseException:
                print("read all files, numbers:%d"%i)
                break

    (三)参数说明

    tf.train.string_input_producer(files_list,shuffle=False,num_epochs=2)

    shuffle=False:表示按序获得文件

    num_epochs=2:表示会遍历两遍全部文件,当我们不设置数值的时候,表示我们可以一直遍历下去,会循环所有文件

    tf.train.string_input_producer(files_list,shuffle=True,num_epochs=3)

    shuffle=False:表示打乱顺序获得文件(是本轮所有文件列表中乱序,不是全局)

    num_epochs=2:表示会遍历三遍全部文件

    二:string_input_producer队列使用(单个Reader、批文件读取)

    import tensorflow as tf
    IMAGE_DIR = "./Images/SourceImgs/"
    QUEUE_DIR = "./Images/QueueImgs/"
    FILELIST = ["1100.jpg","1101.jpg","1102.jpg","1104.jpg","1105.jpg",
               "1110.jpg","1114.jpg","1115.jpg","1116.jpg","1118.jpg"]
    def getFileList(rootDir=IMAGE_DIR,files=FILELIST):
        fsl = []
        for fn in files:
            fsl.append(IMAGE_DIR+fn)
        return fsl

    (一)按批次获取文件

    files_list = getFileList()
    #string_input_producer产生文件名队列
    filename_queue = tf.train.string_input_producer(files_list,shuffle=False,num_epochs=1)
    
    def decode_img(fileQueue):
        #reader从文件名队列中读取数据
        reader = tf.WholeFileReader()
        key,value = reader.read(fileQueue) #返回文件名和文件内容
        return value #返回一个文件
    
    img = decode_img(filename_queue)
    
    image_batch = tf.train.batch([img],batch_size=8,num_threads=2,allow_smaller_final_batch=True) 

    (二)线程调用

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer()) #初始化上面的全局变量
        sess.run(tf.local_variables_initializer()) #初始化上面的局部变量
        
        coord = tf.train.Coordinator()
        #启动start_queue_runners之后,才会开始填充队列
        threads = tf.train.start_queue_runners(sess=sess,coord=coord)
        j = 1
        try:
            while not coord.should_stop():
                images_data = sess.run(image_batch)
                print(images_data.shape)
                for img_data in images_data:
                    with open(QUEUE_DIR+"%d.jpg"%j,"wb") as f:
                        f.write(img_data)
                    j+=1
        except BaseException:
                print("read all files")
        finally:
            coord.request_stop() #将读取文件的线程关闭
        coord.join(threads) #线程回收,将读取文件的子线程加入主线程

    (三)参数说明

    tf.train.batch([img],batch_size=8,num_threads=2,allow_smaller_final_batch=True)

    使用tf.train.batch,按序获取:

    batch_size每一个批次大小为8,
    num_threads使用2线程读取数据,虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。
    allow_smaller_final_batch,默认为false,剩余数据小于batch_size则会被丢弃。

    tf.train.shuffle_batch() 将队列中数据打乱后再读取出来,其他与batch方法类似。

    需要设置:
    capacity:队列中元素的最大数量。
    min_after_dequeue:出队后队列中元素的最小数量,用于确保元素的混合级别。

    补充:

    TensorFlow学习--tf.train.batch与tf.train.shuffle_batch

    tf.train.string_input_producer()和tf.train.slice_input_producer()

    string_input_producer:

    加载图片的reader是reader = tf.WholeFileReader()

    key,value = reader.read(path_queue)其中key是文件名,value是byte类型的文件流二进制。

    slice_input_producer:

    加载图片的reader使用tf.read_file(filename)直接读取。这是两者的一个不同之处!!!

    TensorFlow基础3:数据读取的三种方式

  • 相关阅读:
    python模块安装路径
    yum软件搜索
    项目里用到的python知识点
    python调用C函数
    opencv VideoCapture使用示例
    Django模型层之多表操作
    博客园 装饰
    mysql条件查询-排除null ---oracle、mysql 区分总结
    Android——Fragment详解
    Android——监听事件总结
  • 原文地址:https://www.cnblogs.com/ssyfj/p/13974386.html
Copyright © 2011-2022 走看看