DataLoader
torch.utils.data.Dataset
torch.utils.data.DataLoader
Dataset和DataLoader两个工具类完成数据的加载,
Dataset 用于构造数据集(数据集能够通过索引取出一条数据)、
DataLoader 用于取一批次的数据(Pytorch只支持批数据处理)
We use transforms to perform some manipulation of the data and make it suitable for training.
have two parameters - transform to modify the features
and target_transform to modify the labels
ToTensor converts a PIL image or NumPy ndarray into a FloatTensor
Lambda transforms apply any user-defined lambda function
1.Dataset
two different types of datasets:
01.map-style datasets, CLASS torch.utils.data.Dataset(*args, **kwds)
02.iterable-style datasets. CLASS torch.utils.data.IterableDataset(*args, **kwds)
2.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
DataLoader类
dataset 是定义的数据加载类的对象
batch_size 是每批次数据的大小,通常根据内存等确
shuffle 是每次加载一批数据时是否将其打乱,在训练时一般设置为True、测试时设置为False
num_workers是在读取数据时使用的线程数
collate_fn argument is used to collate lists of samples into batches.
When automatic batching is disabled, collate_fn is
When automatic batching is enabled, collate_fn
collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能- 将一个list的sample组成一个mini-batch 的函数
collate_fn (callable, optional): merges a list of samples to form a mini-batch.
mini-batch_size 的大小来指定一次加载至显存中的图片数量
返回值
- dataloader 封装成一个Batch Size大小的Tensor,用于后面的训练
- dataloader本质是一个可迭代对象,使用iter()访问
- 使用iter(dataloader)返回的是一个迭代器,然后可以使用next访问
类似的情况
torchvision.datasets
torchvision.transforms
torchvision.models
Python 语法复习
01.迭代的对象Iterable-可以用for循环的对象):
一类:list,tuple,dict,set,str (list,dict,str是Iterable,但不是Iterator,
要把list,dict,str等Iterable转换为Iterator可以使用iter()函数)
二类:generator,包含生成器和带yield的generatoe function
而生成器不但可以作用于for,还可以被next()函数不断调用并返回下一个值,
可以被next()函数不断返回下一个值的对象称为迭代器:Iterator
next(iterator[, default]) -- default -- 可选,用于设置在没有下一个元素时返回该默认值,如果不设置,又没有下一个元素则会触发 StopIteration 异常。
02.zip() 函数用于将可迭代的对象作为参数,
将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
,利用 * 号操作符,可以将元组解压为列表
如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同
03.#数据类型转换
# 将numpy数组转化为torch中的tensor:torch.from_numpy();
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
# # 将torch的tensor转换为numpy数组:.numpy();; tensor --> numpy --> image
c = .numpy()
# numpy数组和图片互转 import cv2
# .npy是numpy 专用的二进制文件
import numpy as np
import cv2
# 默认的dtype是float64
arrayexp = np.array([ [[14,7,0],[15,9,0],[48, 40, 27],[ 0,0,0],[ 1,2,0],[ 2,3,1]]
,[[51, 45 ,34],[42, 36, 23],[42, 37, 22],[ 0,0,0],[ 1,2,0],[ 1,2,0]]
,[[42, 39 ,24],[36, 34, 16],[32, 31, 11],[ 1,1,1],[ 1,1,1],[ 1,1,1]]
,[[ 1,2,0],[ 3,4,0],[ 3,4,0],[ 1,1,1],[ 2,3,1],[ 1,2,0]]
,[[ 1,2,0],[ 2,3,0],[ 2,3,0],[ 1,1,1],[ 1,2,0],[ 1,2,0]]
,[[ 2,3,0],[ 2,3,0],[ 3,4,0],[ 2,3,1],[ 1,2,0],[ 1,2,0]]
],dtype=np.uint8)
print(arrayexp.shape,arrayexp.dtype)
# 1.起图片名,2.图片本身
cv2.namedWindow('trans2')
cv2.imshow('trans2',arrayexp)
cv2.waitKey(0)
print("done !")
# 注意 opencv的像素是BGR顺序,然而matplotlib所遵循的是RGB顺序。 opencv的一个像素为:[B,G,R] ,matplotlib的一个像素为:[R,G,B]
参考
TORCH.UTILS.DATA https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset
python中的next()以及iter()函数 https://www.cnblogs.com/SupremeBoy/p/12251240.html
numpy数组和图片互转 https://blog.csdn.net/sin_404/article/details/115397142