3.1 tf.data模块简介及用法
可以从多个数据源非常方便的读取数据
在第一个epoch中缓存下载,接下来的epoch处理的速度就会变快
使用简单的代码构建复杂的输入,可以轻松处理大量数据、不同数据格式以及复杂的转换
tf.data.Dataset 表示一系列的元素、图片,每个元素包含一个或多个Tensor对象,例如在图片通道中,一个元素可能是单个训练样本,具有一对表示图片数据和标签的张量。
创建tf.data.Dataset方式
方式一
直接从Tensor创建Dataset
例如Dataset.from_tensor_slices()
参数也可以是Numpy(tensorflow会自动将它转换成Tensor),Tensor,列表
方式二
通过对一个或多个tf.data.Dataset对象来使用变换(例如Dataset.zip)来创建Dataset
一个Dataset对象包含多个元素,每个元素的结构都相同。每个元素包含一个或多个tf.Tensor对象,这些对象被称为组件。Dataset的属性由构成该Dataset的元素的属性映射得到,元素可以是单个张量、张量元组,也可以是张量的嵌套元组。
使用列表创建
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5,6,7])
for ele in dataset:
print(ele,ele.numpy())
图3-1
可以使用迭代获取每个元素,并且可以使用.numpy()将其转换为array格式
二维列表创建
dataset = tf.data.Dataset.from_tensor_slices([[1,2],[3,4],[5,6]])
for ele in dataset:
print(ele,ele.numpy())
图3-2
其shape为(2,)说明,每个元素长度为2
使用字典创建
dataset = tf.data.Dataset.from_tensor_slices({'a':[1,2,3,4],
'b':[5,6,7,8],
'c':[10,11,12,13]})
for ele in dataset:
print(ele)
图3-3
take()方法
for ele in dataset.take(4): #只遍历前4项
print(ele)
训练网络的时候,每一轮训练的数据的顺序不一样,有助于网络的训练,否则可能会出现网络记忆了数据的顺序的可能。
dataset = dataset.shuffle(7) #打乱数据的顺序,里面数字的大小一般等于元素的个数
dataset = dataset.repeate(count = 3) #将dataset的数据重复三次,如果之前没有进行shuffle,那么会按照顺序重复三次,如果进行了shuffle,那么每一次的重复的数据都是随机排序的。不写参数,默认重复无数次
dataset = dataset.batch(4) 表示每4个元素作为一个数据绑定,也就是训练的时候,每次处理4个元素
可以使用map()对每个元素进行变换
dataset = dataset.map(tf.square) 获得的新数据为原先的平方