  • tf.train.batch and tf.train.shuffle_batch

    Both functions pull elements from a queue in batches; they are commonly used to fetch training samples batch by batch.

    These two APIs are rather unintuitive, and I have not figured out some of the parameters yet, so for now let's just cover the common usage.

    batch

    Fetches the specified number of elements from the queue and assembles them into a batch.

    def batch(tensors, batch_size, num_threads=1, capacity=32,
              enqueue_many=False, shapes=None, dynamic_pad=False,
              allow_smaller_final_batch=False, shared_name=None, name=None):
      """Creates batches of tensors in `tensors`."""

    tensors: the tensors to batch, typically the element returned by an input producer such as slice_input_producer

    batch_size: the number of elements per batch

    capacity: the capacity of the internal prefetch queue that batch creates, i.e. the maximum number of elements it can hold while waiting to be batched

    import numpy as np
    import tensorflow as tf

    label = np.asarray(range(0, 100))
    # label = tf.cast(label, tf.int32)
    input_queue = tf.train.slice_input_producer([label], shuffle=False)   # enqueue the 100 labels one at a time, in order
    label_batch = tf.train.batch(input_queue, batch_size=19, num_threads=1, capacity=5)
    
    with tf.Session() as sess:
        coord = tf.train.Coordinator()      # thread coordinator
        threads = tf.train.start_queue_runners(sess, coord)     # start all queue runners collected in the graph
        for j in range(8):
            out = sess.run([label_batch])
            print(out)
        coord.request_stop()
        coord.join(threads)
    # [array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18])]
    # [array([19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37])]
    # [array([38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56])]
    # [array([57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75])]
    # [array([76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94])]
    # [array([95, 96, 97, 98, 99,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13])]
    # [array([14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32])]
    # [array([33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51])]
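
    The signature above also lists enqueue_many, which the example does not use. The sketch below is a hypothetical variant (TF 1.x graph mode, same data as above) illustrating how enqueue_many=True changes the input convention: the tensor passed in is treated as a whole batch of examples along dimension 0, so no slice_input_producer is needed.

    import numpy as np
    import tensorflow as tf

    # Hypothetical variant: feed the 100 labels as one tensor and let batch()
    # split them along dimension 0, because enqueue_many=True.
    label = tf.constant(np.arange(100), dtype=tf.int32)
    label_batch = tf.train.batch([label], batch_size=19, enqueue_many=True, capacity=32)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        print(sess.run(label_batch))    # first batch: [ 0  1  2 ... 18]
        coord.request_stop()
        coord.join(threads)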

    shuffle_batch

    Randomly fetches the specified number of elements from the queue and assembles them into a batch.

    def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
                      num_threads=1, seed=None, enqueue_many=False, shapes=None,
                      allow_smaller_final_batch=False, shared_name=None, name=None):
      """Creates batches by randomly shuffling tensors."""

    capacity: the capacity of the queue; this parameter must be larger than min_after_dequeue

    The recommended value is

    • capacity = min_after_dequeue + (num_threads + a small safety margin) * batch_size
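
    For example (illustrative numbers), with min_after_dequeue = 1000, num_threads = 1, batch_size = 10, and a safety margin of 3, this gives capacity = 1000 + (1 + 3) * 10 = 1040.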

    min_after_dequeue: the minimum number of elements left in the queue after a dequeue operation; it is typically used to control how well the elements are shuffled.

    It defines the size of the buffer from which random samples are drawn: a larger value gives better mixing, but it makes startup slower and uses more memory.

    import numpy as np
    import tensorflow as tf

    images = np.random.random([100, 2])
    label = np.asarray(range(0, 100))
    images = tf.cast(images, tf.float32)
    label = tf.cast(label, tf.int32)
    input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
    # randomly dequeue (image, label) pairs; at least min_after_dequeue elements stay buffered
    image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=10, num_threads=1,
                                                      capacity=64, min_after_dequeue=10)
    
    with tf.Session() as sess:
        coord = tf.train.Coordinator()      # thread coordinator
        threads = tf.train.start_queue_runners(sess, coord)     # start all queue runners collected in the graph
        for _ in range(5):
            image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
            print(image_batch_v, label_batch_v)
    
        coord.request_stop()
        coord.join(threads)
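
    Because shuffle_batch samples at random from its buffer, the labels in label_batch_v come out in non-consecutive order, while each row of image_batch_v stays paired with its own label, unlike the strictly sequential output of tf.train.batch above.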

    Note that both APIs keep producing data, so they can be run repeatedly in a loop and will return a different batch each time; a sketch of how to bound the stream with num_epochs follows below.
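
    The sketch below (a hypothetical variation on the example above, TF 1.x graph mode) limits the stream by passing num_epochs to slice_input_producer; once the data has been consumed that many times, the pipeline raises tf.errors.OutOfRangeError, which can be caught to end the loop.

    import numpy as np
    import tensorflow as tf

    label = np.asarray(range(0, 100))
    input_queue = tf.train.slice_input_producer([label], num_epochs=2, shuffle=False)
    label_batch = tf.train.batch(input_queue, batch_size=25, allow_smaller_final_batch=True)

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())   # num_epochs is tracked by a local variable
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        try:
            while not coord.should_stop():
                print(sess.run(label_batch))
        except tf.errors.OutOfRangeError:
            print('done: 2 epochs consumed')
        finally:
            coord.request_stop()
            coord.join(threads)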

    References:

    https://blog.csdn.net/akadiao/article/details/79645221

    https://blog.csdn.net/u013555719/article/details/77679964

  • Original post: https://www.cnblogs.com/yanshw/p/12467753.html