1、dataset是初入pytorch最重要的东西,在复现项目的时候,最需要改的就是数据集。
如果弄明白了pytorch中dataset类,你可以创建适应任意模型的数据集接口。
2、所谓数据集,无非就是一组{x:y}的集合吗,你只需要在这个类里说明“有一组{x:y}的集合”就可以了。
对于图像分类任务,图像+分类
对于目标检测任务,图像+bbox、分类
对于超分辨率任务,低分辨率图像+超分辨率图像
对于文本分类任务,文本+分类
...
你只需定义好这个项目的x和y是什么。好了,上面都是扯闲篇,我们直接看dataset代码:
链接:https://blog.csdn.net/leviopku/article/details/99958182
这个链接非常的详细。
Pytorch用torch.utils.data.Dataset
构建数据集,想要构建自己的数据集,则需继承Dataset
类,并重写两个方法:
__len__
:定义整个数据集的长度。使用len(dataset)
时会被调用。__getitem__
:用于索引数据集中的数据,比如dataset[i]
。
Dataset基类
PyTorch 读取图片,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。Dataset
类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。
看一下源码:
这里有一个getitem函数,getitem函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。
list的制作方法通常是将图片的路径和标签信息存储在一个txt中,然后从txt中读取,所以总结一下基本流程:
制作存储了图片路径和标签信息的txt
将这些信息转化成list,list的每一个元素对应一个样本
通过getitem函数,读取数据和标签。
其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的__iter__(self)(后面再将)。
总而言之,要让PyTorch读取自己的数据集,只要两步:
制作图片数据的索引
构建Dataset子类
制作图片数据索引
非常简单,就是一些基本的操作,百度一下“”python如何保存txt文件“”就可以知道了。
然后一般来说,txt都是这样的格式
./Data/train/01.png 0
./Data/train/02.png 0
./Data/train/03.png 1
./Data/train/04.png 1
构建Dataset子类
下面我们构建一下Dataset的子类,叫他MyDataset类:
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Datset):
def __init__(self,txt_path,transform=None,target_transform=None):
fh = open(txt_path,'r')
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0].int(words[1])))
self.imgs = imgs
self.transform = transform
def __getitem__(self,index):
fn,label = self.imgs[index]
img=Image.open(fn).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(sefl.imgs)
Init
初始化中,我们从已经准备好的txt中获取了图片的路径和表亲啊,并且春初在self.imgs这意味着self.imgs是一个list就像上面我们讲的那样
初始化中,初始化了transform,transform是一个Compose类型,transform中包含一个list,list中定义了各种对图像进行的操作,比如随机剪裁,旋转反转等。
一个图片都进来之后,会经过数据处理(数据增强),最终变成另外一张图片,也就是模型的输入数据。但是PyTorch的数据增强是将原始图片进行处理,是不会生成新的图片。因此假如我们使用randomcrop这样的随机操作的时候,每次epoch输入进来的图片不会是一摸一样的,达到样本多样性的功能
getitem
self.imgs是一个list,每一个元素都是一个二元tuple,这很好理解(str1,str2)这样的
利用Image.open对图片进行读取,img类型为Image,mode=‘RGB’
用transform对图片进行处理,里面可能有什么 标准化(减均值除以标准差),随机剪裁什么的(后面会细说)
这样Mydataset就构建好了,剩下的操作就交给DataLoader,在DataLoader中,会触发Mydataset中的getitem函数读取一张图片的数据和标签,并将多个图片拼接成一个batch返回,每一个batch才是模型真正的输入。