zoukankan      html  css  js  c++  java
  • TensorFlow dataset.shuffle、batch、repeat的使用详解

    https://www.jb51.net/article/178976.htm

    直接看代码例子,有详细注释!!

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    import tensorflow as tf
    import numpy as np
     
     
    d = np.arange(0,60).reshape([6, 10])
     
    # 将array转化为tensor
    data = tf.data.Dataset.from_tensor_slices(d)
     
    # 从data数据集中按顺序抽取buffer_size个样本放在buffer中,然后打乱buffer中的样本
    # buffer中样本个数不足buffer_size,继续从data数据集中安顺序填充至buffer_size,
    # 此时会再次打乱
    data = data.shuffle(buffer_size=3)
     
    # 每次从buffer中抽取4个样本
    data = data.batch(4)
     
    # 将data数据集重复,其实就是2个epoch数据集
    data = data.repeat(2)
     
    # 构造获取数据的迭代器
    iters = data.make_one_shot_iterator()
     
    # 每次从迭代器中获取一批数据
    batch = iters.get_next()
     
    sess = tf.Session()
     
    sess.run(batch)
    # 数据集完成遍历完之后,继续抽取的话会报错:OutOfRangeError
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    In [21]: d
    Out[21]:
    array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
      [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
      [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
      [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
      [40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
      [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])
    In [22]: sess.run(batch)
    Out[22]:
    array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
      [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
      [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
      [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])
     
    In [23]: sess.run(batch)
    Out[23]:
    array([[40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
      [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])

    从输出结果可以看出:

    shuffle是按顺序将数据放入buffer里面的;

    当repeat函数在shuffle之后的话,是将一个epoch的数据集抽取完毕,再进行下一个epoch的。

    那么,当repeat函数在shuffle之前会怎么样呢?如下:

    1
    2
    3
    4
    5
    data = data.repeat(2)
     
    data = data.shuffle(buffer_size=3)
     
    data = data.batch(4)
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    In [25]: sess.run(batch)
    Out[25]:
    array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
      [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
      [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
      [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])
     
    In [26]: sess.run(batch)
    Out[26]:
    array([[50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
      [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
      [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
      [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]])
     
    In [27]: sess.run(batch)
    Out[27]:
    array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
      [50, 51, 52, 53, 54, 55, 56, 57, 58, 59],
      [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
      [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

    可以看出,其实它就是先将数据集复制一遍,然后把两个epoch当成同一个新的数据集,一直shuffle和batch下去。

  • 相关阅读:
    第40次全国计算机等级考试监考
    [再寄小读者之数学篇](2014-07-27 打印错误吧)
    日积月累的名典[2014-10-7]
    2014年全球“高被引科学家”数学类名单
    年轻尼姑的19句话
    PostgreSQL的 initdb 源代码分析之十六
    PostgreSQL的 initdb 源代码分析之十五
    PostgreSQL的 initdb 源代码分析之十四
    PostgreSQL的 initdb 源代码分析之十三
    PostgreSQL的 initdb 源代码分析之十二
  • 原文地址:https://www.cnblogs.com/yibeimingyue/p/13869479.html
Copyright © 2011-2022 走看看