zoukankan      html  css  js  c++  java
  • 6、TensorFlow基础(四)队列和线程

    队列和线程

      和 TensorFlow 中的其他组件一样,队列(queue)本身也是图中的一个节点,是一种有状态的节点,其他节点,如入队节点(enqueue)和出队节点(dequeue),可以修改它的内容。例如,入队节点可以把新元素插到队列末尾,出队节点可以把队列前面的元素删除。本节主要介绍队列、队列管理器、线程和协调器的有关知识。

    1、队列:

      TensorFlow 中主要有两种队列,即 FIFOQueue 和 RandomShuffleQueue,它们的源代码实现在 tensorflow-1.1.0/tensorflow/python/ops/data_flow_ops.py 中。

      (1)、FIFOQueue

       FIFOQueue创建一个先入先出队列。列如,我们在训练一些语音、文字样本时,使用循环神经网络的网络结构,希望读入的训练样本是有序的,就要用FIFOQUEUE。

      我们行创建一个含有队列的图:

     1 # -*- coding: UTF-8 -*-
     2 # date:2018/6/22
     3 # User:WangHong
     4 import tensorflow as tf
     5 #创建一个先入先出的队列,初始化队列插入0.1,0.2,0.3三个数字
     6 q = tf.FIFOQueue(3,'float')
     7 init = q.enqueue_many(([0.1,0.2,0.3],))
     8 #定义出队、+1、入队操作
     9 x =q.dequeue()
    10 y = x+1
    11 q_inc = q.enqueue([y])
    12 #然后开启一个会话,执行2次q_inc操作,随后查看队列内容。
    13 with tf.Session() as sess:
    14     sess.run(init)
    15     quelen = sess.run(q.size())
    16     for i in range(2):
    17         sess.run(q_inc)#执行2次操作,队列中的值变为0.3,1.1,1.2
    18     quelen = sess.run(q.size())
    19     for i in range(quelen):
    20         print(sess.run(q.dequeue()))#输出队列的值

    结果:

         

      (2)、RandomShuffleQueue

        RandomShuffleQueue创建一个随机队列,在出队列时,是以随机的顺序产生元素的,例如,我们在训练一些图像样本是,使用CNN的网络结构,希望。可以无序的读入训练样本,就要用RandomShuffleQueue,每次随机产生一个训练样本。

        RandomShuffleQueue在在TensorFlow使用异步计算时很重要。因为TensorFlow的会话是支持多线程的,我们可以在主线程里执行训练操作,使用RandomShuffleQueue作为训练输入,开多线程来准备训练样本,将样本压入队列后,主线性会从线程中每次取出mini-batch的样本进行训练。

    例子;

     1 # -*- coding: UTF-8 -*-
     2 # date:2018/6/22
     3 # User:WangHong
     4 import tensorflow as tf
     5 q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes = 'float')
     6 #然后开启一个会话
     7 sess = tf.Session()
     8 for i in range(0,10):#10次入队
     9     sess.run(q.enqueue(i))
    10 for i in range(0,8):#8次出队
    11     print(sess.run(q.dequeue()))

    结果:发现结果是乱序

         

    我们尝试修改入队次数为 12 次,再运行,发现程序阻断不动,或者我们尝试修改出队此
    时为 10 次,即不保留队列最小长度,发现队列输出 8 次结果后,在终端仍然阻断了。

    阻断一般发生在:
    ● 队列长度等于最小值,执行出队操作;
    ● 队列长度等于最大值,执行入队操作。

      上面的例子都是在会话的主线程中进行入队操作。当数据量很大时,入队操作从硬盘中读
    取数据,放入内存中,主线程需要等待入队操作完成,才能进行训练操作。会话中可以运行多
    个线程,我们使用线程管理器 QueueRunner 创建一系列的新线程进行入队操作,让主线程继续
    使用数据,即训练网络和读取数据是异步的,主线程在训练网络,另一个线程在将数据从硬盘
    读入内存。

     2、队列管理器

      创建一个含有队列的图:

     1 # -*- coding: UTF-8 -*-
     2 # date:2018/6/22
     3 # User:WangHong
     4 import tensorflow as tf
     5 q = tf.FIFOQueue(1000,'float')
     6 counter = tf.Variable(0.0)#计数器
     7 increment_op = tf.assign_add(counter,tf.constant(1.0))#操作给计数器加一
     8 enqueue_op = q.enqueue(counter)#操作;计数器值加入队列
     9 #创建UI个队列计数器QueueRunner,用这两个操作向队列q添加元素。目前使用一个线程
    10 qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*1)
    11 #启动一个会话,从队列管理器qr中创建线程:
    12 #主线程
    13 with tf.Session() as sess:
    14     sess.run(tf.global_variables_initializer())
    15     enqueue_threads = qr.create_threads(sess,start=True)#启动入队线程
    16     for i in range(10):
    17         print(sess.run(q.dequeue()))

    结果:

         

      

      不是我们期待的自然数列,并且线程被阻断。这是因为加 1 操作和入队操作不同步,可能
    加 1 操作执行了很多次之后,才会进行一次入队操作。另外,因为主线程的训练(出队操作)
    和读取数据的线程的训练(入队操作)是异步的,主线程会一直等待数据送入。

      QueueRunner 有一个问题就是:入队线程自顾自地执行,在需要的出队操作完成之后,程
    序没法结束。这样就要使用 tf.train.Coordinator 来实现线程间的同步,终止其他线程。

     3、线程协调器

     1 # -*- coding: UTF-8 -*-
     2 # date:2018/6/22
     3 # User:WangHong
     4 import tensorflow as tf
     5 q = tf.FIFOQueue(1000,'float')
     6 counter = tf.Variable(0.0)#计数器
     7 increment_op = tf.assign_add(counter,tf.constant(1.0))#操作给计数器加一
     8 enqueue_op = q.enqueue(counter)#操作;计数器值加入队列
     9 #创建UI个队列计数器QueueRunner,用这两个操作向队列q添加元素。目前使用一个线程
    10 qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*1)
    11 #启动一个会话,从队列管理器qr中创建线程:
    12 #主线程
    13 with tf.Session() as sess:
    14     sess.run(tf.global_variables_initializer())
    15     coord = tf.train.Coordinator()
    16     enqueue_threads = qr.create_threads(sess,start=True)#启动入队线程
    17     coord.request_stop()  # 通知其他线程关闭
    18     for i in range(10):
    19         try:
    20             print(sess.run(q.dequeue()))
    21         except tf.errors.OutOfRangeError:
    22             break
    23     coord.join(enqueue_threads)#join操作等待其他线程结束,其他所有线程关闭之后,这一函数才能返回

    所有队列管理器被默认加在图的 tf.GraphKeys.QUEUE_RUNNERS 集合中。

  • 相关阅读:
    maven 笔记
    面试题53:在排序数组中查找数字
    面试题52:两个链表的第一个公共节点
    面试题51:数组中的逆序对
    面试题50_2:字符流中第一个只出现一次的字符
    面试题50:第一个只出现一次的字符
    面试题49:丑数
    面试题48:最长不含重复字符的连续子字符串
    面试题47:礼物的最大值
    面试题8:二叉树的下一个节点
  • 原文地址:https://www.cnblogs.com/wanshuai/p/9213422.html
Copyright © 2011-2022 走看看