zoukankan      html  css  js  c++  java
  • Tensorflow Dataset.from_generator使用示例

    shapes = (tf.TensorShape([None, None]), tf.TensorShape([10, 10]))
    # 传入的是一个generator,即返回字段为yield的函数,不可传入嵌套生成器
    # dataSet output_types参数必选,output_shapes参数可选,不选会直接适配数据的shape
    # 参数就是一个元组
    data_set = tf.data.Dataset.from_generator(gen_epochs,
                                              output_types=(tf.int32, tf.int32),
                                              output_shapes=shapes,
                                              args=(n, batch_size, 10))

    之前的一篇博文(https://blog.csdn.net/foreseerwang/article/details/80170210)介绍了使用Tensorflow Dataset进行数据导入的方法及其优势。最近在实际使用中越发感觉到这个方式非常好用,尤其是发现了.from_generator这个method。

    关于Dataset.from_generator的简单介绍,请参见如下两个链接:

    https://tensorflow.google.cn/versions/master/api_docs/python/tf/data/Dataset#repeat

    https://blog.csdn.net/dqcfkyqdxym3f8rb0/article/details/79342369

    注意,Dataset.from_generator在旧版Tensorflow中没有,起码在1.3版本tf.contrib.data.Dataset中还没有,后来用的1.7版本就有了。

    我们知道,tensorflow的基本原理是先构造一个计算图,最后再统一计算。为此,tf重写了几乎所有常见函数,用于构造计算图,而且tensorflow不支持循环、选择等普通编程语言的常见操作。这就给编程使用带来比较大的麻烦。具体到data feeding上,也是如此。虽然设计了placeholder、train.slice_input_producer系列、Dataset等多种方式,但使用中仍有各种不便,尤其是在输入形式复杂、需要多重变换的时候更是如此。而Dataset.from_generator可以在一定程度上解决这个问题。

    简单的说,Dataset.from_generator可以使用普通编程语言编写的外部子函数生成Dataset,这样几乎不受tensorflow编程不便的影响。先举一个最简单的示例:

    '''
    import pickle
    fr=open('/media/dell/D/qcc/RandLA-Net/data/semantic_kitti/dataset/sequences_0.06/00/KDTree/000001.pkl','rb')
    inf = pickle.load(fr)
    doc = open('1.txt', 'a')
    print(inf, file=doc)
    print(inf)
    '''
    
    # demo of Dataset.from_generator
    # blog.csdn.net/foreseerwang
    # QQ: 50834
    
    """
    Expected outputs:
    Batch No. 0:
    [0 1 2 3]
    Batch No. 1:
    [4 0 1 2]
    Batch No. 2:
    [3 4 0 1]
    Batch No. 3:
    [2 3 4]
    end!
    """
    
    import numpy as np
    import tensorflow as tf
    
    
    def data_generator():
        dataset = np.array(range(5))
        for d in dataset:
            #print(d)
            yield d
    
    
    dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32), (tf.TensorShape([])))
    dataset = dataset.repeat(3) #3==epoch
    dataset = dataset.batch(4) #4==batchsize
    
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()
    
    with tf.Session() as sess:
        try:
            batch_num = 0
            while True:
                one_batch = sess.run(one_element)
                print('Batch No. %d:' % batch_num)
                print(one_batch)
                print('')
                batch_num += 1
    
        except tf.errors.OutOfRangeError:
            print('end!')
            
            

    很显然,这个的输出如下:

    1. Batch No. 0:
    2. [0 1 2 3]
    3.  
    4. Batch No. 1:
    5. [4 0 1 2]
    6.  
    7. Batch No. 2:
    8. [3 4 0 1]
    9.  
    10. Batch No. 3:
    11. [2 3 4]
    12.  
    13. end!

    下面给出一个复杂的问题。假设需要输入如下序列:

    A B

    A C B

    C

    其中A/B/C分别代表一个文件,例如一张图片或是一个文本文件。每一行是一条记录,按行读入,并聚集多行形成batch,譬如每4行形成一个batch。这里有两个难点:1.每一行/每一条记录的元素长度不一样;2.读入元素A/B/C之后还要以之作为文件名读入文件内容。现有各种data feeding方式似乎很难同时解决这两个难点,除了Dataset.from_generator。

    针对这个问题,使用Dataset.from_generator的一个简化版示例如下:

    1. # demo of Dataset.from_generator
    2. # blog.csdn.net/foreseerwang
    3. # QQ: 50834
    4.  
    5. """
    6. Expected outputs:
    7.  
    8. Batch No. 0:
    9. [[ 1 2 3]
    10. [ 2 3 -1]]
    11.  
    12. Batch No. 1:
    13. [[ 3 -1 -1]
    14. [ 4 5 -1]]
    15.  
    16. Batch No. 2:
    17. [[ 6 7 8]
    18. [ 9 -1 -1]]
    19.  
    20. Batch No. 3:
    21. [[10 11 12]
    22. [13 14 -1]]
    23.  
    24. Batch No. 4:
    25. [[15 -1 -1]]
    26.  
    27. end!
    28. """
    29.  
    30. import io
    31. import numpy as np
    32. import tensorflow as tf
    33.  
    34. class DataFeeder:
    35.  
    36. def __init__(self, filenames):
    37. self.filenames = filenames
    38.  
    39. def file_readline(self):
    40. for filename in self.filenames:
    41. fr = io.open(filename, 'r', encoding='utf-8')
    42.  
    43. while True:
    44. file_line = fr.readline()
    45. if not file_line:
    46. break
    47.  
    48. datalist = file_line.split()
    49. # if datalist is a list of filename, file contents can
    50. # be read and appendded here.
    51. yield np.asarray(datalist, dtype='int32')
    52.  
    53. fr.close()
    54.  
    55. def generate_batch(self, batch_size, num_epochs=None):
    56. dataset = tf.data.Dataset.from_generator(self.file_readline,
    57. tf.int32,
    58. tf.TensorShape([None]))
    59.  
    60. dataset = dataset.repeat(num_epochs)
    61. dataset = dataset.padded_batch(
    62. batch_size,
    63. padded_shapes=tf.TensorShape([3]),
    64. padding_values=-1)
    65.  
    66. iterator = dataset.make_one_shot_iterator()
    67. out_batch = iterator.get_next()
    68.  
    69. return out_batch
    70.  
    71. filenames = ['a.txt', 'b.txt', 'c.txt']
    72. data_feeder = DataFeeder(filenames)
    73. one_batch = data_feeder.generate_batch(batch_size=2, num_epochs=1)
    74.  
    75. with tf.Session() as sess:
    76. try:
    77. batch_num = 0
    78. while True:
    79. data_batch = sess.run(one_batch)
    80. print('Batch No. %d:' % batch_num)
    81. print(data_batch)
    82. print('')
    83. batch_num+=1
    84.  
    85. except tf.errors.OutOfRangeError:
    86. print('end!')

    其中三个文本文件a.txt/b.txt/c.txt的内容分别如下:

    a.txt:

    1 2 3
    2 3
    3

    b.txt:

    4 5
    6 7 8
    9

    c.txt:

    10 11 12
    13 14
    15

    运行以上代码的输出为:

    1. Batch No. 0:
    2. [[ 1 2 3]
    3. [ 2 3 -1]]
    4.  
    5. Batch No. 1:
    6. [[ 3 -1 -1]
    7. [ 4 5 -1]]
    8.  
    9. Batch No. 2:
    10. [[ 6 7 8]
    11. [ 9 -1 -1]]
    12.  
    13. Batch No. 3:
    14. [[10 11 12]
    15. [13 14 -1]]
    16.  
    17. Batch No. 4:
    18. [[15 -1 -1]]
    19.  
    20. end!

    目前的输出,每个batch是batch_size * 3的矩阵。实际上,1~15的数字可以是某个图片的文件名,在file_readline()函数中读出这些数字后,可以继续读出这些文件的内容,并形成更高维度的Dataset输出,譬如:batch_size * img_size * img_size * img_channel的Dataset。

    最后,说几点注意事项(详见代码):

    1. generator函数不能有输入参数,但如果是class内的一个函数,可以使用self参数,这也是传递参数的一个手段;

    2. 上述class中,建议传递文件名,在generator中打开处理再关闭,而不应该在外面打开(fr=open(filename, ‘r’)),然后把fr传递给generator读取。实践表明:后面这种方法形成的dataset不能repeat;

    3. 因为序列不等长,在形成dataset batch时需要使用Dataset.padded_batch方法。

  • 相关阅读:
    VB.net byval和byref
    IOS 常用宏定义(二)
    目录权限Linux存储实验四:NFS的安装与配置
    服务器负载均衡nginx+keepalived 高可用负载均衡
    博客文件第二部分 Linux Shell高级编程技巧——第一章 深入讨论
    查看进程第二部分 Linux Shell高级编程技巧——第二章 Shell工具
    响应命令使用Ubuntu架设ftp服务器
    进程内存Linux下top命令
    客户端服务器SSH原理和使用
    安装数据空间虚拟CentOS访问Windows下共享文件
  • 原文地址:https://www.cnblogs.com/yibeimingyue/p/13805105.html
Copyright © 2011-2022 走看看