zoukankan      html  css  js  c++  java
  • pytorch学习---dataset

    1、dataset是初入pytorch最重要的东西,在复现项目的时候,最需要改的就是数据集。

     如果弄明白了pytorch中dataset类,你可以创建适应任意模型的数据集接口。

    2、所谓数据集,无非就是一组{x:y}的集合吗,你只需要在这个类里说明“有一组{x:y}的集合”就可以了。

    对于图像分类任务,图像+分类

    对于目标检测任务,图像+bbox、分类

    对于超分辨率任务,低分辨率图像+超分辨率图像

    对于文本分类任务,文本+分类

    ...

    你只需定义好这个项目的x和y是什么。好了,上面都是扯闲篇,我们直接看dataset代码:

    链接:https://blog.csdn.net/leviopku/article/details/99958182

    这个链接非常的详细。

    Pytorch用torch.utils.data.Dataset构建数据集,想要构建自己的数据集,则需继承Dataset类,并重写两个方法:

      • __len__ :定义整个数据集的长度。使用len(dataset)时会被调用。
      • __getitem__:用于索引数据集中的数据,比如dataset[i]

    Dataset基类
    PyTorch 读取图片,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。Dataset
    类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。
    看一下源码:

    这里有一个getitem函数,getitem函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。

    list的制作方法通常是将图片的路径和标签信息存储在一个txt中,然后从txt中读取,所以总结一下基本流程:

    制作存储了图片路径和标签信息的txt
    将这些信息转化成list,list的每一个元素对应一个样本
    通过getitem函数,读取数据和标签。
    其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,其实触发去读取图片这些操作的是DataLoader里的__iter__(self)(后面再将)。

    总而言之,要让PyTorch读取自己的数据集,只要两步:

    制作图片数据的索引
    构建Dataset子类
    制作图片数据索引
    非常简单,就是一些基本的操作,百度一下“”python如何保存txt文件“”就可以知道了。
    然后一般来说,txt都是这样的格式
    ./Data/train/01.png 0
    ./Data/train/02.png 0
    ./Data/train/03.png 1
    ./Data/train/04.png 1

    构建Dataset子类
    下面我们构建一下Dataset的子类,叫他MyDataset类:

    from PIL import Image
    from torch.utils.data import Dataset
    class MyDataset(Datset):
    def __init__(self,txt_path,transform=None,target_transform=None):
    fh = open(txt_path,'r')
    imgs = []
    for line in fh:
    line = line.rstrip()
    words = line.split()
    imgs.append((words[0].int(words[1])))
    self.imgs = imgs
    self.transform = transform
    def __getitem__(self,index):
    fn,label = self.imgs[index]
    img=Image.open(fn).convert('RGB')
    if self.transform is not None:
    img = self.transform(img)
    return img,label
    def __len__(self):
    return len(sefl.imgs)

    Init
    初始化中,我们从已经准备好的txt中获取了图片的路径和表亲啊,并且春初在self.imgs这意味着self.imgs是一个list就像上面我们讲的那样

    初始化中,初始化了transform,transform是一个Compose类型,transform中包含一个list,list中定义了各种对图像进行的操作,比如随机剪裁,旋转反转等。

    一个图片都进来之后,会经过数据处理(数据增强),最终变成另外一张图片,也就是模型的输入数据。但是PyTorch的数据增强是将原始图片进行处理,是不会生成新的图片。因此假如我们使用randomcrop这样的随机操作的时候,每次epoch输入进来的图片不会是一摸一样的,达到样本多样性的功能

    getitem
    self.imgs是一个list,每一个元素都是一个二元tuple,这很好理解(str1,str2)这样的
    利用Image.open对图片进行读取,img类型为Image,mode=‘RGB’
    用transform对图片进行处理,里面可能有什么 标准化(减均值除以标准差),随机剪裁什么的(后面会细说)
    这样Mydataset就构建好了,剩下的操作就交给DataLoader,在DataLoader中,会触发Mydataset中的getitem函数读取一张图片的数据和标签,并将多个图片拼接成一个batch返回,每一个batch才是模型真正的输入。

  • 相关阅读:
    大数据的分页优化的思路
    escape()、encodeURI()、encodeURIComponent()区别详解
    PHP面向对象知识总结
    mysql 简单优化规则
    mysql语句内部优化
    js onmouseout的冒泡事件
    Android 开机自启动
    查看 AndroidManifest.xml文件
    Hierarchy Viewer显示视图性能指标
    Profile GPU rendering
  • 原文地址:https://www.cnblogs.com/h694879357/p/15363919.html
Copyright © 2011-2022 走看看