zoukankan      html  css  js  c++  java
  • Dataset 和 DataLoader 详解

    Dataset 是 PyTorch 中用来表示数据集的一个抽象类,我们的数据集可以用这个类来表示,至少需要覆写下面两个方法:

        1)__len__:一般用来返回数据集大小。

        2)__getitem__:实现这个方法后,可以通过下标的方式 dataset[i] 的来取得第 $i$ 个数据。

    DataLoader 本质上就是一个 iterable(内部定义了 __iter__ 方法),__iter__ 被定义成生成器,使用 yield 来返回数据,

    并利用多进程来加速 batch data 的处理,DataLoader 组装好数据后返回的是 Tensor 类型的数据。

    注意:DataLoader 是间接通过 Dataset 来获得数据的,然后进行组装成一个 batch 返回,因为采用了生成器,所以每次只会组装

    一个 batch 返回,不会一次性组装好全部的 batch,所以 DataLoader 节省的是 batch 的内存,并不是指数据集的内存,数据集可

    以一开始就全部加载到内存里,也可以分批加载,这取决于 Dataset 中 __init__ 函数的实现。

    举个例子:

    import torch
    import numpy as np
    from torch.utils.data import Dataset
    from torch.utils.data import DataLoader
    
    class DiabetesDataset(Dataset):
        def __init__(self, filepath):
            # 因为数据集比较小,所以全部加载到内存里了
            data = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
            self.len = data.shape[0]
            self.x_data = torch.from_numpy(data[:,:-1])
            self.y_data = torch.from_numpy(data[:,[-1]])
    
        def __getitem__(self, index):
            return self.x_data[index], self.y_data[index]
    
        def __len__(self):
            return self.len
    
    dataset = DiabetesDataset('diabetes.csv.gz')
    train_loader = DataLoader(dataset=dataset,   # 传递数据集
                              batch_size=32,     # 小批量的数据大小,每次加载一batch数据
                              shuffle=True,      # 打乱数据之间的顺序
                              num_workers=2)     # 使用多少个子进程来加载数据,默认为0, 代表使用主线程加载batch数据
    
    for epoch in range(100):  # 训练 100 轮
        for i, data in enumerate(train_loader, 0):  # 每次惰性返回一个 batch 数据
            iuputs, label = data
            ...
    
  • 相关阅读:
    图的最大匹配算法
    二分图的最小顶点覆盖 最大独立集 最大团
    后缀数组:倍增法和DC3的简单理解
    后缀自动机浅析
    微积分学习笔记一:极限 导数 微分
    微积分学习笔记二
    微积分学习笔记三:定积分
    微积分学习笔记四:空间向量基础
    微积分学习笔记五:多元函数微积分
    程序员之路--回顾2015,展望2016
  • 原文地址:https://www.cnblogs.com/yanghh/p/14074744.html
Copyright © 2011-2022 走看看