假设分布式任务包含n个ps节点, m个worker节点. m, n>0. 希望所有worker的任务结束后,所有节点才终止。
- 方法: 借助队列tf.FIFOQueue实现。
- 原理: tf.FIFOQueue 是个全局的的队列, 出队函数dequeue有这个特点:
If the queue is empty when this operation executes, it will block until there is an element to dequeue.
利用这个性质, 设置ps服务器的停止条件:- ps端执行m个出队列操作。 队列初始都是空队列, 因此,一开始出队操作都被阻塞。
- 每个worker完成任务后, 往ps的队列中放入一个元素,使得ps端的一个出队操作能执行完成。
- 参考: https://github.com/hn826/distributed-tensorflow/blob/master/distributed-deep-mnist-with-queue.py
更新
- 实际中, 可以定义全局变量, 通过判断全局变量状态控制终止条件。
* class GlobalStatus(object):
def __init__(self):
with tf.variable_scope("global_status", reuse=tf.AUTO_REUSE):
self.status = tf.get_variable("status", (), trainable=False,
dtype=tf.int32, initializer=tf.constant_initializer(0))
self.send_op = self.status.assign(1)
def change_status(self, sess):
sess.run(self.send_op)
def is_done(self, sess):
z = sess.run(self.status)
return z>0