zoukankan      html  css  js  c++  java
  • pytorch datasets与dataloader阐释说明

    一.torch.utils.data包含Dataset,Sampler,Dataloader

    torch.utils.data主要包括以下三个类:
    1. class torch.utils.data.Dataset: 作用: (1) 创建数据集,有__getitem__(self, index)函数来根据索引序号获取图片和标签, 有__len__(self)函数来获取数据集的长度.

    其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder.

    2. class torch.utils.data.sampler.Sampler(data_source)
    参数: data_source (Dataset) – dataset to sample from

    作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度.

    3. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

    二. datasets.ImageFolder  ,可用于提取分类网络图片使用

    参数:

    root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
    transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
    target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
    loader:表示数据集加载方式,通常默认加载方式即可。
    is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)

    属性值:

    • self.classes:用一个 list 保存类别名称
    • self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
    • self.imgs:保存(img-path, class) tuple的 list

      

    def verity_datasets():
    root = './datasets/train' # 根路径
    data = datasets.ImageFolder(root) # 可以理解载入dataset
    print('data.classes:',data.classes) # 类别信息
    print('data.class_to_idx:',data.class_to_idx) # 类别与索引
    print('data.imgs:',data.imgs) # 图片地址与标签
    img = cv2.imread(data.imgs[0][0])
    plt.imshow(img)
    plt.show()
    for img,label in data:
    image=cv2.cvtColor(np.asarray(img),cv2.COLOR_RGB2BGR)
    print( image.shape,label)

    代码运行结果如下:

    若需要添加transform 可使用如下代码:

    from torchvision.datasets import ImageFolder
    from torchvision import transforms

    #加上transforms
    normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
    transform=transforms.Compose([
    transforms.RandomCrop(180),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
    normalize
    ])

    dataset=ImageFolder('./data/train',transform=transform)

    三.dataloader加载方式,需要添加自己信息如何更改源码如下:

    import numpy as np
    from PIL import Image
    from torch.utils.data.dataset import TensorDataset,Dataset
    from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
    from torch.tensor import Tensor
    T_co = TypeVar('T_co', covariant=True)
    T = TypeVar('T')


    class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
    *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self,my_info, *tensors: Tensor) -> None:
    assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
    self.tensors = tensors
    self.my_info=my_info

    def __getitem__(self, index):
    return tuple([tensor[index],self.my_info[index]] for tensor in self.tensors)

    def __len__(self):
    return self.tensors[0].size(0)



    def verity_dataloader():


    x = torch.linspace(1, 10, 10)
    y = torch.linspace(10, 1, 10)
    k = [{'img_meta':20} for _ in range(10)]
    print(x,y)
    # 数据集包装数据和标签,实际是一个迭代器,类似dataset方法,一般为输入图片x与对应标签y,
    # 但如果想更改传入更多参数,需要自己更改源码,主要是__getiterm__方法。
    # torch_dataset = torch.utils.data.TensorDataset(x, y) # 未更改源码
    torch_dataset = TensorDataset(k,x,y) # 已经更改了源码

    loader = torch.utils.data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,
    batch_size=3,
    shuffle=True,
    num_workers=2,
    drop_last=True # True丢弃最后bath不足数据,false不丢弃
    )

    for step, (batch_x, batch_y) in enumerate(loader):
    # training
    print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

    结果如下:

    参考博客:

    https://blog.csdn.net/qq_39507748/article/details/105394808

    https://blog.csdn.net/tsq292978891/article/details/79414512

    处理算法通用的辅助的code,如读取txt文件,读取xml文件,将xml文件转换成txt文件,读取json文件等
  • 相关阅读:
    python面向对象(一)
    ls和cd命令详解
    SHELL 中的变量
    Shell基础
    Python版飞机大战
    Python模块制作
    Linux的cut命令
    Linux中的wc命令
    Ubuntu系统下adb devices 不能显示手机设备
    app耗电量测试工具--PowerTutor
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/14856214.html
Copyright © 2011-2022 走看看