shapes = (tf.TensorShape([None, None]), tf.TensorShape([10, 10])) # 传入的是一个generator,即返回字段为yield的函数,不可传入嵌套生成器 # dataSet output_types参数必选,output_shapes参数可选,不选会直接适配数据的shape # 参数就是一个元组 data_set = tf.data.Dataset.from_generator(gen_epochs, output_types=(tf.int32, tf.int32), output_shapes=shapes, args=(n, batch_size, 10))
之前的一篇博文(https://blog.csdn.net/foreseerwang/article/details/80170210)介绍了使用Tensorflow Dataset进行数据导入的方法及其优势。最近在实际使用中越发感觉到这个方式非常好用,尤其是发现了.from_generator这个method。
关于Dataset.from_generator的简单介绍,请参见如下两个链接:
https://tensorflow.google.cn/versions/master/api_docs/python/tf/data/Dataset#repeat
https://blog.csdn.net/dqcfkyqdxym3f8rb0/article/details/79342369
注意,Dataset.from_generator在旧版Tensorflow中没有,起码在1.3版本tf.contrib.data.Dataset中还没有,后来用的1.7版本就有了。
我们知道,tensorflow的基本原理是先构造一个计算图,最后再统一计算。为此,tf重写了几乎所有常见函数,用于构造计算图,而且tensorflow不支持循环、选择等普通编程语言的常见操作。这就给编程使用带来比较大的麻烦。具体到data feeding上,也是如此。虽然设计了placeholder、train.slice_input_producer系列、Dataset等多种方式,但使用中仍有各种不便,尤其是在输入形式复杂、需要多重变换的时候更是如此。而Dataset.from_generator可以在一定程度上解决这个问题。
简单的说,Dataset.from_generator可以使用普通编程语言编写的外部子函数生成Dataset,这样几乎不受tensorflow编程不便的影响。先举一个最简单的示例:
''' import pickle fr=open('/media/dell/D/qcc/RandLA-Net/data/semantic_kitti/dataset/sequences_0.06/00/KDTree/000001.pkl','rb') inf = pickle.load(fr) doc = open('1.txt', 'a') print(inf, file=doc) print(inf) ''' # demo of Dataset.from_generator # blog.csdn.net/foreseerwang # QQ: 50834 """ Expected outputs: Batch No. 0: [0 1 2 3] Batch No. 1: [4 0 1 2] Batch No. 2: [3 4 0 1] Batch No. 3: [2 3 4] end! """ import numpy as np import tensorflow as tf def data_generator(): dataset = np.array(range(5)) for d in dataset: #print(d) yield d dataset = tf.data.Dataset.from_generator(data_generator, (tf.int32), (tf.TensorShape([]))) dataset = dataset.repeat(3) #3==epoch dataset = dataset.batch(4) #4==batchsize iterator = dataset.make_one_shot_iterator() one_element = iterator.get_next() with tf.Session() as sess: try: batch_num = 0 while True: one_batch = sess.run(one_element) print('Batch No. %d:' % batch_num) print(one_batch) print('') batch_num += 1 except tf.errors.OutOfRangeError: print('end!')
很显然,这个的输出如下:
-
Batch No. 0:
-
[0 1 2 3]
-
-
Batch No. 1:
-
[4 0 1 2]
-
-
Batch No. 2:
-
[3 4 0 1]
-
-
Batch No. 3:
-
[2 3 4]
-
-
end!
下面给出一个复杂的问题。假设需要输入如下序列:
A B
A C B
C
…
其中A/B/C分别代表一个文件,例如一张图片或是一个文本文件。每一行是一条记录,按行读入,并聚集多行形成batch,譬如每4行形成一个batch。这里有两个难点:1.每一行/每一条记录的元素长度不一样;2.读入元素A/B/C之后还要以之作为文件名读入文件内容。现有各种data feeding方式似乎很难同时解决这两个难点,除了Dataset.from_generator。
针对这个问题,使用Dataset.from_generator的一个简化版示例如下:
-
# demo of Dataset.from_generator
-
# blog.csdn.net/foreseerwang
-
# QQ: 50834
-
-
"""
-
Expected outputs:
-
-
Batch No. 0:
-
[[ 1 2 3]
-
[ 2 3 -1]]
-
-
Batch No. 1:
-
[[ 3 -1 -1]
-
[ 4 5 -1]]
-
-
Batch No. 2:
-
[[ 6 7 8]
-
[ 9 -1 -1]]
-
-
Batch No. 3:
-
[[10 11 12]
-
[13 14 -1]]
-
-
Batch No. 4:
-
[[15 -1 -1]]
-
-
end!
-
"""
-
-
import io
-
import numpy as np
-
import tensorflow as tf
-
-
class DataFeeder:
-
-
def __init__(self, filenames):
-
self.filenames = filenames
-
-
def file_readline(self):
-
for filename in self.filenames:
-
fr = io.open(filename, 'r', encoding='utf-8')
-
-
while True:
-
file_line = fr.readline()
-
if not file_line:
-
break
-
-
datalist = file_line.split()
-
# if datalist is a list of filename, file contents can
-
# be read and appendded here.
-
yield np.asarray(datalist, dtype='int32')
-
-
fr.close()
-
-
def generate_batch(self, batch_size, num_epochs=None):
-
dataset = tf.data.Dataset.from_generator(self.file_readline,
-
tf.int32,
-
tf.TensorShape([None]))
-
-
dataset = dataset.repeat(num_epochs)
-
dataset = dataset.padded_batch(
-
batch_size,
-
padded_shapes=tf.TensorShape([3]),
-
padding_values=-1)
-
-
iterator = dataset.make_one_shot_iterator()
-
out_batch = iterator.get_next()
-
-
return out_batch
-
-
filenames = ['a.txt', 'b.txt', 'c.txt']
-
data_feeder = DataFeeder(filenames)
-
one_batch = data_feeder.generate_batch(batch_size=2, num_epochs=1)
-
-
with tf.Session() as sess:
-
try:
-
batch_num = 0
-
while True:
-
data_batch = sess.run(one_batch)
-
print('Batch No. %d:' % batch_num)
-
print(data_batch)
-
print('')
-
batch_num+=1
-
-
except tf.errors.OutOfRangeError:
-
print('end!')
其中三个文本文件a.txt/b.txt/c.txt的内容分别如下:
a.txt:
1 2 3
2 3
3
b.txt:
4 5
6 7 8
9
c.txt:
10 11 12
13 14
15
运行以上代码的输出为:
-
Batch No. 0:
-
[[ 1 2 3]
-
[ 2 3 -1]]
-
-
Batch No. 1:
-
[[ 3 -1 -1]
-
[ 4 5 -1]]
-
-
Batch No. 2:
-
[[ 6 7 8]
-
[ 9 -1 -1]]
-
-
Batch No. 3:
-
[[10 11 12]
-
[13 14 -1]]
-
-
Batch No. 4:
-
[[15 -1 -1]]
-
-
end!
目前的输出,每个batch是batch_size * 3的矩阵。实际上,1~15的数字可以是某个图片的文件名,在file_readline()函数中读出这些数字后,可以继续读出这些文件的内容,并形成更高维度的Dataset输出,譬如:batch_size * img_size * img_size * img_channel的Dataset。
最后,说几点注意事项(详见代码):
1. generator函数不能有输入参数,但如果是class内的一个函数,可以使用self参数,这也是传递参数的一个手段;
2. 上述class中,建议传递文件名,在generator中打开处理再关闭,而不应该在外面打开(fr=open(filename, ‘r’)),然后把fr传递给generator读取。实践表明:后面这种方法形成的dataset不能repeat;
3. 因为序列不等长,在形成dataset batch时需要使用Dataset.padded_batch方法。