zoukankan      html  css  js  c++  java
  • pytorch.utils.data

    概览

    torch.utils.data主要是负责容纳数据集、数据打散、分批等操作。

    这里面有三个概念:数据集dataset,抽样器sampler,数据加载器dataloader。其中第三个就是最终对外的接口,也是最重要的。

    它们之间的关系是:首先需要根据源数据创建数据集dataset,然后根据dataset创建抽样器sampler,最后同时通过dataset和sampler来创建dataloader,这就是我们最终需要的。这个在训练、测试的时候,会得到batch数据。

    dataset

    第一个是dataset,就是常规理解的数据集。

    数据集主要分为两种:map-style和iterable-style

    map-style数据集,一般都是继承Dataset类 ,必须要实现__getitem__()__len__()方法,表示从索引或者key到数据样本的映射

    iterable-style数据集,一般都是继承IterableDataset类,必须实现__iter__()方法,表示在数据样本上迭代。一般从一些流中实时获取数据(比如数据库、远程服务器或者日志),是无法进行随机读取的,这时就主要使用迭代式数据集。

    一般如果数据量小,使用map-style就可以了,如果数据量很大,需要从数据流中获取,那就使用iterable-style

    对应到具体的类,有以下六个:

    • torch.utils.data.Dataset
    • torch.utils.data.IterableDataset
    • torch.utils.data.TensorDataset
    • torch.utils.data.ConcatDataset
    • torch.utils.data.ChainDataset
    • torch.utils.data.Subset

    除此之外,torch.utils.data还包含了两个函数

    • torch.utils.data.get_worker_info()
    • torch.utils.data.random_split()

    sampler

    sampler是抽样器,作用在dataset上面

    抽样的方式也有几个方式:

    按顺序抽样,随机抽样,在子集合中随机抽样,带权重的抽样等等

    包括以下类:

    • class Sampler
    • class SequentialSampler
    • class RandomSampler
    • class SubsetRandomSampler
    • class WeightedRandomSampler
    • class BatchSampler
    • class DistributedSampler

    生成sampler的最终目的就是为了创建dataloader。

    dataLoader

    DataLoader是核心。

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
               batch_sampler=None, num_workers=0, collate_fn=None,
               pin_memory=False, drop_last=False, timeout=0,
               worker_init_fn=None)
    

    构建DataLoader有几个重要的参数:

    • dataset是数据集,
    • batch_size
    • shuffle 是否每一轮都将数据进行打散,最好通过sampler来打散,否则使用SequentialSampler的时候也会被打散。
    • sampler 生成indices
    • collate_fn
    • pin_memory 含义参考pytorch pinned memory

    实例1:通过TensorDataset快速生成dataloader

    数据中有字符串类型的时候慎用。

    import torch
    from torch.utils.data import DataLoader, TensorDataset, Dataset, RandomSampler
    import numpy as np
    
    
    # 创建TensorDataset
    feature = torch.tensor(np.arange(100))
    dataset = TensorDataset([feature, feature])
    sampler = RandomSampler(dataset)
    dataloader = DataLoader(dataset, batch_size=5, sampler=sampler)
    
    for epoch in range(2):
        print('epoch=', epoch)
        for index, batch in enumerate(dataloader):
            print(batch)
            if index > 10:
                break
    
    epoch= 0
    [tensor([79,  6, 81, 35, 21], dtype=torch.int32), tensor([79,  6, 81, 35, 21], dtype=torch.int32)]
    [tensor([43, 98, 86, 23, 68], dtype=torch.int32), tensor([43, 98, 86, 23, 68], dtype=torch.int32)]
    [tensor([ 0, 36, 60,  1, 91], dtype=torch.int32), tensor([ 0, 36, 60,  1, 91], dtype=torch.int32)]
    [tensor([71, 59, 72, 75, 52], dtype=torch.int32), tensor([71, 59, 72, 75, 52], dtype=torch.int32)]
    [tensor([45,  2, 73, 46, 95], dtype=torch.int32), tensor([45,  2, 73, 46, 95], dtype=torch.int32)]
    [tensor([82, 37, 24, 12, 16], dtype=torch.int32), tensor([82, 37, 24, 12, 16], dtype=torch.int32)]
    [tensor([90, 11, 70, 31, 53], dtype=torch.int32), tensor([90, 11, 70, 31, 53], dtype=torch.int32)]
    [tensor([15,  7, 64, 22, 65], dtype=torch.int32), tensor([15,  7, 64, 22, 65], dtype=torch.int32)]
    [tensor([ 3, 87,  4, 17, 99], dtype=torch.int32), tensor([ 3, 87,  4, 17, 99], dtype=torch.int32)]
    [tensor([83, 20, 19, 89, 42], dtype=torch.int32), tensor([83, 20, 19, 89, 42], dtype=torch.int32)]
    [tensor([97, 58,  8, 38, 30], dtype=torch.int32), tensor([97, 58,  8, 38, 30], dtype=torch.int32)]
    [tensor([54, 56, 48, 27, 57], dtype=torch.int32), tensor([54, 56, 48, 27, 57], dtype=torch.int32)]
    epoch= 1
    [tensor([66, 15, 37, 82, 47], dtype=torch.int32), tensor([66, 15, 37, 82, 47], dtype=torch.int32)]
    [tensor([75, 70,  5, 99, 33], dtype=torch.int32), tensor([75, 70,  5, 99, 33], dtype=torch.int32)]
    [tensor([80, 76, 55, 29, 41], dtype=torch.int32), tensor([80, 76, 55, 29, 41], dtype=torch.int32)]
    [tensor([79, 17, 63, 92, 74], dtype=torch.int32), tensor([79, 17, 63, 92, 74], dtype=torch.int32)]
    [tensor([52, 53, 58, 38, 87], dtype=torch.int32), tensor([52, 53, 58, 38, 87], dtype=torch.int32)]
    [tensor([84, 59, 77, 48, 71], dtype=torch.int32), tensor([84, 59, 77, 48, 71], dtype=torch.int32)]
    [tensor([56, 16, 27, 81, 60], dtype=torch.int32), tensor([56, 16, 27, 81, 60], dtype=torch.int32)]
    [tensor([50, 73, 46, 28, 32], dtype=torch.int32), tensor([50, 73, 46, 28, 32], dtype=torch.int32)]
    [tensor([45, 40, 10, 25,  9], dtype=torch.int32), tensor([45, 40, 10, 25,  9], dtype=torch.int32)]
    [tensor([12, 49, 22, 51, 20], dtype=torch.int32), tensor([12, 49, 22, 51, 20], dtype=torch.int32)]
    [tensor([ 6, 68, 72, 24, 67], dtype=torch.int32), tensor([ 6, 68, 72, 24, 67], dtype=torch.int32)]
    [tensor([57, 96, 23, 97, 98], dtype=torch.int32), tensor([57, 96, 23, 97, 98], dtype=torch.int32)]
    

    自定义Dataset

    import torch
    import torch.nn as nn
    import numpy as np
    import random
    from torch.utils.data import Dataset, DataLoader
    
    class ToyDataset(Dataset):
        def __init__(self):
            self.Data = np.arange(32).reshape(16, 2).tolist()
            self.Target = np.random.randint(0, 2, (16,1)).tolist()
    
        def __getitem__(self, index):
            txt = torch.LongTensor(self.Data[index])
            label = torch.LongTensor(self.Target[index])
            return txt, label
        
        def __len__(self):
            return len(self.Data)
    
  • 相关阅读:
    jquery获得option的值和对option进行操作
    laravel 在添加操作自动完成对时间保存修改
    laravel使用ajax
    mysql操作查询结果case when then else end用法举例
    Laravel框架数据库CURD操作、连贯操作总结
    laravel5.1关于lists函数的bug
    详解AngularJS中的filter过滤器用法
    javascript中的时间处理
    angularJs--$on、$emit和$broadcast的使用
    angularJs--<ui-select>
  • 原文地址:https://www.cnblogs.com/YoungF/p/13941346.html
Copyright © 2011-2022 走看看