zoukankan      html  css  js  c++  java
  • 理解TensorFlow的Queue

    https://www.jianshu.com/p/d063804fb272

    这篇文章来说说TensorFlow里与Queue有关的概念和用法。

    其实概念只有三个:

    • Queue是TF队列和缓存机制的实现
    • QueueRunner是TF中对操作Queue的线程的封装
    • Coordinator是TF中用来协调线程运行的工具

    虽然它们经常同时出现,但这三样东西在TensorFlow里面是可以单独使用的,不妨先分开来看待。

    1. Queue

    根据实现的方式不同,分成具体的几种类型,例如:

    • tf.FIFOQueue 按入列顺序出列的队列
    • tf.RandomShuffleQueue 随机顺序出列的队列
    • tf.PaddingFIFOQueue 以固定长度批量出列的队列
    • tf.PriorityQueue 带优先级出列的队列
    • ... ...

    这些类型的Queue除了自身的性质不太一样外,创建、使用的方法基本是相同的。

    创建函数的参数:

    tf.FIFOQueue(capacity, dtypes, shapes=None, names=None ...)
    

    Queue主要包含入列(enqueue)出列(dequeue)两个操作。enqueue操作返回计算图中的一个Operation节点,dequeue操作返回一个Tensor值。Tensor在创建时同样只是一个定义(或称为“声明”),需要放在Session中运行才能获得真正的数值。下面是一个单独使用Queue的例子:

    import tensorflow as tf
    tf.InteractiveSession()
    
    q = tf.FIFOQueue(2, "float")
    init = q.enqueue_many(([0,0],))
    
    x = q.dequeue()
    y = x+1
    q_inc = q.enqueue([y])
    
    init.run()
    q_inc.run()
    q_inc.run()
    q_inc.run()
    x.eval()  # 返回1
    x.eval()  # 返回2
    x.eval()  # 卡住
    

    注意,如果一次性入列超过Queue Size的数据,enqueue操作会卡住,直到有数据(被其他线程)从队列取出。对一个已经取空的队列使用dequeue操作也会卡住,直到有新的数据(从其他线程)写入。

    2. QueueRunner

    Tensorflow的计算主要在使用CPU/GPU和内存,而数据读取涉及磁盘操作,速度远低于前者操作。因此通常会使用多个线程读取数据,然后使用一个线程消费数据。QueueRunner就是来管理这些读写队列的线程的。

    QueueRunner需要与Queue一起使用(这名字已经注定了它和Queue脱不开干系),但并不一定必须使用Coordinator。看下面这个例子:

    import tensorflow as tf  
    import sys  
    q = tf.FIFOQueue(10, "float")  
    counter = tf.Variable(0.0)  #计数器
    # 给计数器加一
    increment_op = tf.assign_add(counter, 1.0)
    # 将计数器加入队列
    enqueue_op = q.enqueue(counter)
    
    # 创建QueueRunner
    # 用多个线程向队列添加数据
    # 这里实际创建了4个线程,两个增加计数,两个执行入队
    qr = tf.train.QueueRunner(q, enqueue_ops=[increment_op, enqueue_op] * 2)
    
    # 主线程
    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()
    # 启动入队线程
    qr.create_threads(sess, start=True)
    for i in range(20):
        print (sess.run(q.dequeue()))
    

    增加计数的进程会不停的后台运行,执行入队的进程会先执行10次(因为队列长度只有10),然后主线程开始消费数据,当一部分数据消费被后,入队的进程又会开始执行。最终主线程消费完20个数据后停止,但其他线程继续运行,程序不会结束。

    3. Coordinator

    Coordinator是个用来保存线程组运行状态的协调器对象,它和TensorFlow的Queue没有必然关系,是可以单独和Python线程使用的。例如:

    import tensorflow as tf
    import threading, time
    
    # 子线程函数
    def loop(coord, id):
        t = 0
        while not coord.should_stop():
            print(id)
            time.sleep(1)
            t += 1
            # 只有1号线程调用request_stop方法
            if (t >= 2 and id == 1):
                coord.request_stop()
    
    # 主线程
    coord = tf.train.Coordinator()
    # 使用Python API创建10个线程
    threads = [threading.Thread(target=loop, args=(coord, i)) for i in range(10)]
    
    # 启动所有线程,并等待线程结束
    for t in threads: t.start()
    coord.join(threads)
    

    将这个程序运行起来,会发现所有的子线程执行完两个周期后都会停止,主线程会等待所有子线程都停止后结束,从而使整个程序结束。由此可见,只要有任何一个线程调用了Coordinator的request_stop方法,所有的线程都可以通过should_stop方法感知并停止当前线程。

    将QueueRunner和Coordinator一起使用,实际上就是封装了这个判断操作,从而使任何一个现成出现异常时,能够正常结束整个程序,同时主线程也可以直接调用request_stop方法来停止所有子线程的执行。

    4. 在一起

    在TensorFlow中用Queue的经典模式有两种,都是配合了QueueRunner和Coordinator一起使用的。

    第一种,显式的创建QueueRunner,然后调用它的create_threads方法启动线程。例如下面这段代码:

    import tensorflow as tf
    
    # 1000个4维输入向量,每个数取值为1-10之间的随机数
    data = 10 * np.random.randn(1000, 4) + 1
    # 1000个随机的目标值,值为0或1
    target = np.random.randint(0, 2, size=1000)
    
    # 创建Queue,队列中每一项包含一个输入数据和相应的目标值
    queue = tf.FIFOQueue(capacity=50, dtypes=[tf.float32, tf.int32], shapes=[[4], []])
    
    # 批量入列数据(这是一个Operation)
    enqueue_op = queue.enqueue_many([data, target])
    # 出列数据(这是一个Tensor定义)
    data_sample, label_sample = queue.dequeue()
    
    # 创建包含4个线程的QueueRunner
    qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
    
    with tf.Session() as sess:
        # 创建Coordinator
        coord = tf.train.Coordinator()
        # 启动QueueRunner管理的线程
        enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
        # 主线程,消费100个数据
        for step in range(100):
            if coord.should_stop():
                break
            data_batch, label_batch = sess.run([data_sample, label_sample])
        # 主线程计算完成,停止所有采集数据的进程
        coord.request_stop()
        coord.join(enqueue_threads)
    

    第二种,使用全局的start_queue_runners方法启动线程。

    import tensorflow as tf
    
    # 同时打开多个文件,显示创建Queue,同时隐含了QueueRunner的创建
    filename_queue = tf.train.string_input_producer(["data1.csv","data2.csv"])
    reader = tf.TextLineReader(skip_header_lines=1)
    # Tensorflow的Reader对象可以直接接受一个Queue作为输入
    key, value = reader.read(filename_queue)
    
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        # 启动计算图中所有的队列线程
        threads = tf.train.start_queue_runners(coord=coord)
        # 主线程,消费100个数据
        for _ in range(100):
            features, labels = sess.run([data_batch, label_batch])
        # 主线程计算完成,停止所有采集数据的进程
        coord.request_stop()
        coord.join(threads)
    

    在这个例子中,tf.train.string_input_produecer会将一个隐含的QueueRunner添加到全局图中(类似的操作还有tf.train.shuffle_batch等)。

    由于没有显式地返回QueueRunner来用create_threads启动线程,这里使用了tf.train.start_queue_runners方法直接启动tf.GraphKeys.QUEUE_RUNNERS集合中的所有队列线程。

    这两种方式在效果上是等效的。

    参考文章

    1. tensorflow中关于队列使用的实验
    2. cs20si课件slides_09


    作者:巾梵
    链接:https://www.jianshu.com/p/d063804fb272
    來源:简书
    简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。
  • 相关阅读:
    网络编程
    mysql
    python 基础
    vim 操作
    linux 基本命令
    基本库使用(urllib,requests)
    震撼功能:逐浪CMS全面支持PWA移动生成意指未来
    硬件能力与智能AI-Zoomla!逐浪CMS2 x3.9.2正式发布
    从送外卖到建站售主机还有共享自行车说起-2017年8月江西IDC排行榜与发展报告
    HTTP协议知多少-关于http1.x、http2、SPDY的相关知识
  • 原文地址:https://www.cnblogs.com/DjangoBlog/p/9467867.html
Copyright © 2011-2022 走看看