在使用TensorFlow进行异步计算时,队列是一种强大的机制。
为了感受一下队列,让我们来看一个简单的例子。我们先创建一个“先入先出”的队列(FIFOQueue),并将其内部所有元素初始化为零。然后,我们构建一个TensorFlow图,它从队列前端取走一个元素,加上1之后,放回队列的后端。慢慢地,队列的元素的值就会增加。
TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。Coordinator类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常,QueueRunner类用来协调多个工作线程同时将多个张量推入同一个队列中。
队列概述
队列,如FIFOQueue和RandomShuffleQueue,在TensorFlow的张量异步计算时都非常重要。
例如,一个典型的输入结构:是使用一个RandomShuffleQueue来作为模型训练的输入:
-
多个线程准备训练样本,并且把这些样本推入队列。
-
一个训练线程执行一个训练操作
import tensorflow as tf # 模拟 同步处理数据后取数据 # define queue Q = tf.FIFOQueue(3,tf.float32) # input data 这里列表[] 会被当成一个张量,所以要在后面加一个逗号 enq_many = Q.enqueue_many([[0.1,0.2,0.3],]) # define take data from queue take data then +1 put this data back out_q = Q.dequeue() # tf中允许重载 data = out_q+1 en_q = Q.enqueue(data) with tf.Session() as sess: # 初始化队列 sess.run(enq_many) # 处理数据 for i in range(100): # 这里en_q 对data依赖data->out_1->enq_many sess.run(en_q) #训练数据 for i in range(Q.size().eval()): print(sess.run(Q.dequeue()))
QueueRunner 会创建一组线程, 这些线程可以重复的执行Enquene操作, 他们使用同一个Coordinator来处理线程同步终止。此外,一个QueueRunner会运行一个closer thread,当Coordinator收到异常报告时,这个closer thread会自动关闭队列。
您可以使用一个queue runner,来实现上述结构。 首先建立一个TensorFlow图表,这个图表使用队列来输入样本。增加处理样本并将样本推入队列中的操作。增加training操作来移除队列中的样本。
tf.Coordinator
Coordinator类用来帮助多个线程协同工作,多个线程同步终止。 其主要方法有:
-
should_stop():如果线程应该停止则返回True。
-
request_stop(): 请求该线程停止。
-
join():等待被指定的线程终止。
coord = tf.train.Coordinator() threads = qr.create_threads(sess,coord = coord,start=True) coord.request_stop() coord.join(threads)
import tensorflow as tf # 模拟异步 子线程存入样本,主线程读取样本 # 1 定一个队列 Q = tf.FIFOQueue(1000,tf.float32) # 2 定义子线程 +1 进队列 var = tf.Variable(0.0)# 变量op # assign_add 是自增 这里直接+=1 是不可以的 data = tf.assign_add(var,tf.constant(1.0))# 加法op en_q = Q.enqueue(data) # 队列管理器op qr = tf.train.QueueRunner(Q,enqueue_ops=[en_q]*2) with tf.Session() as sess: tf.global_variables_initializer().run() # 开启线程管理器 coord = tf.train.Coordinator() # 真正开启线程 threads = qr.create_threads(sess,coord=coord,start=True) for i in range(300): print(sess.run(Q.dequeue())) # 强行停止 coord.request_stop() coord.join(threads)