代码
def data_iterator(tfrecords, batch_size=2, shuffle=True, train=True, num_parallel_reads=3):
# 声明TFRecordDataset
dataset = tf.data.TFRecordDataset(tfrecords, num_parallel_reads=num_parallel_reads)
dataset = dataset.map(_parse_function)
if shuffle:
dataset = dataset.shuffle(buffer_size=256)
if train:
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
return iterator
说明:
dataset.shuffle(buffer_size)
会在batch之间打乱,具体见前面的笔记num_parallel_reads
参数可以并行加载数据,实测可以在batch内部打乱数据。
如果数据制作的时候顺序固定,相似较大,比如按顺序crop的数据得到多个tfrecords,把这两项加上可以较为充分的打乱数据