zoukankan      html  css  js  c++  java
  • 什么是pytorch(4.数据集加载和处理)(翻译)

    数据集加载和处理

    这里主要涉及两个包:torchvision.datasets 和torch.utils.data.DatasetDataLoader

    torchvision.datasets是一些包装好的数据集

    里边所有可用的dataset都是 torch.utils.data.Dataset 的子类,这些子类都要有 __getitem__ __len__ 方法是实现。

    这样, 定义的数据集才能够被 torch.utils.data.DataLoader ,DataLoader能够使用torch.multiprocessing并行加载许多样本

    例如:

    imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
    data_loader = torch.utils.data.DataLoader(imagenet_data,
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=args.nThreads) 

    当我们需要使用我们的数据集的时候,就需要进行包装成DataLoader能够识别的Dataset这样就能把我们从无穷的数据预处理中解脱出来。
    创建数据集
    首先导入,创建一个子类:
    from torch.utils.data import Dataset
    import torch
    class MyDateset(Dataset):
        def __init__(self,num=10000,transform=None):  #这里就可以写你的参数了,比如文件夹什么的。
            self.len=num
            self.transform=transform
        def __len__(self):
            return self.len
        def __getitem__(self,idx):
            data=torch.rand(3,3,5)  #这里就是你的数据图像的话就是C*M*N的tensor,这里创建了一个3*3*5的张量
            label=torch.LongTensor([1])   #label也是需要一个张量
            if self.transform:    #这里就是数据预处理的部分 、
                data=self.transform(data)  #处理完必须要返回torch.Tensor类型
            return data,label

    下面我们测试一下:
    md=MyDateset()
    print(md[0])
    print(len(md))
    输出:
    (tensor([[[0.2753, 0.8114, 0.2916, 0.9600, 0.5057], [0.8595, 0.1195, 0.8065, 0.6393, 0.6213], 
    [0.0997, 0.8590, 0.2469, 0.2158, 0.5296]], [[0.4764, 0.0561, 0.5866, 0.6129, 0.1882],
    [0.4666, 0.9362, 0.5397, 0.3065, 0.4307], [0.4700, 0.6202, 0.3649, 0.6357, 0.5181]],
    [[0.9794, 0.8127, 0.9842, 0.8821, 0.2447], [0.2320, 0.6406, 0.5683, 0.5637, 0.2734],
    [0.2131, 0.5853, 0.5633, 0.9069, 0.9250]]]), tensor([1]))
    10000
    输出:这样我们就自定义了一个数据集Dataset,这样我们需要使用已有的数据集的时候就可以知道torchvision.dataset下许多数据集的构成了。
     
    预处理数据

    返回来再看上边定义数据集里有个参数transform,从定义getitem函数里看到,transform其实是一个函数。
    torchvision.transforms里就包括了好多的操作。当然它主要处理的是图像,就是C*H*W类型的举证了。
    可以直接这样使用:
    from torchvision import transforms

    md=MyDateset(transform=transforms.Normalize((0,0,0),(0.1,0.2,0.3)))
    print(md[0])
    (tensor([[[2.5435, 9.1073, 4.1653, 9.4720, 0.7595],
             [0.4840, 7.2377, 3.1578, 4.5391, 2.7440],
             [4.6951, 4.7698, 1.1308, 0.5321, 3.5101]],
    
            [[2.6714, 4.5143, 0.0582, 0.2880, 0.2565],
             [2.2951, 0.0680, 0.3542, 4.7372, 2.0162],
             [1.4065, 2.5195, 0.8911, 4.8432, 3.1045]],
    
            [[2.7726, 2.5199, 0.8066, 0.7089, 2.0651],
             [1.8641, 1.6599, 0.5546, 2.8716, 2.0964],
             [2.5320, 1.5349, 1.8792, 0.0933, 3.2289]]]), tensor([1]))
    更多的变换参见:https://pytorch.org/docs/master/torchvision/transforms.html

    当然我们也可以自定义一个函数传入:
    def add1(x):
        return x+1
    md=MyDateset(transform=add1)
    print(md[0])
    输出:
    (tensor([[[1.9552, 1.1294, 1.9435, 1.6476, 1.2726],
             [1.1544, 1.7726, 1.1975, 1.9914, 1.2694],
    当然也可以组合起来个transform形成一个一个处理级联:
    tc=transforms.Compose([transforms.Normalize((0,0,0),(0.1,0.2,0.3)),add1])
    md=MyDateset(transform=tc)
    print(md[0])

    输出:
    (tensor([[[ 1.9232,  6.4972,  7.9916,  4.3426, 10.9737],
             [ 5.4062,  2.6264,  6.8474,  4.7810,  3.3232],
             [ 8.6633,  4.1399,  2.3371,  5.5058,  3.9724]],
    等等。


    用Dataloader加载数据集

    在训练网络,测试网络时我们就需要使用刚才定义好的数据集了。

    from torch.utils.data import Dataset, DataLoader
    md=MyDateset()
    print(md[1])
    dl=DataLoader(md, batch_size=4,  shuffle=False,  num_workers=4)
    print(len(dl.dataset))

    这样dl就可以在程序里循环生成批样本,提供训练,测试了。



  • 相关阅读:
    [Windows] 一些简单的CMD命令
    开发过程中用到的触发器
    MyEclipse8.5配置struts等框架
    Java编程中中文乱码问题的研究及解决方案
    开源的SSH框架优缺点分析
    java 合并排序算法、冒泡排序算法、选择排序算法、插入排序算法、快速排序
    html,CSS文字大小单位px、em、pt的关系换算
    HTML常用标签参考学习
    匹配中文字符的正则表达式
    Oracle 取上周一到周末的sql
  • 原文地址:https://www.cnblogs.com/yjphhw/p/9811038.html
Copyright © 2011-2022 走看看