zoukankan      html  css  js  c++  java
  • PyTorch笔记之 Dataset 和 Dataloader

    简介

    在 PyTorch 中,我们的数据集往往会用一个类去表示,在训练时用 Dataloader 产生一个 batch 的数据

    https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

    比如官方例子中对 CIFAR10 图像数据集进行分类,就有用到这样的操作,具体代码如下所示

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)

    简单说,用 一个类 抽象地表示数据集,而 Dataloader 作为迭代器,每次产生一个 batch 大小的数据,节省内存

    Dataset

    Dataset 是 PyTorch 中用来表示数据集的一个抽象类,我们的数据集可以用这个类来表示,至少覆写下面两个方法即可

    这返回数据前可以进行适当的数据处理,比如将原文用一串数字序列表示

    • __len__:数据集大小
    • __getitem__:实现这个方法后,可以通过下标的方式( dataset[i] )的来取得第 $i$ 个数据

    下面我们来为编写一个类表示一个情感二分类数据集,继续用苏神整理的数据集

    https://github.com/bojone/bert4keras/tree/master/examples/datasets

    数据集没有表头,只有2列,一列是评论(文本),另一列是标签,以制表符进行分隔

    from torch.utils.data import Dataset, DataLoader
    import pandas as pd
    
    class SentimentDataset(Dataset):
        def __init__(self, path_to_file):
            self.dataset = pd.read_csv(path_to_file, sep="	", names=["text", "label"])
        def __len__(self):
            return len(self.dataset)
        def __getitem__(self, idx):
            text = self.dataset.loc[idx, "text"]
            label = self.dataset.loc[idx, "label"]
            sample = {"text": text, "label": label}
            return sample

    Dataloader

    基本使用

    Dataloader 就是一个迭代器,最基本的使用就是传入一个 Dataset 对象,它就会根据参数 batch_size 的值生成一个 batch 的数据

    if __name__ == "__main__":
        sentiment_dataset = SentimentDataset("sentiment.test.data")
        sentiment_dataloader = DataLoader(sentiment_dataset, batch_size=4, shuffle=True, num_workers=2)
        for idx, batch_samples in enumerate(sentiment_dataloader):
            text_batchs, text_labels = batch_samples["text"], batch_samples["label"]
            print(text_batchs)

    Sampler

    PyTorch 提供了 Sampler 模块,用来对数据进行采样,可以在 DataLoader 的通过 sampler 参数调用

    一般我们的加载训练集的 dataloader ,shuffle参数都会设置为True ,这时候使用了一个默认的采样器——RandomSampler

    当 shuffle 设置为 False 时,默认使用的是 SequencetialSampler,其实就是按顺序取出数据集中的元素

    在 PyTorch 中默认实现了以下 Sampler,如果我们要使用别的 Sampler, shuffle 要设置为 False

    • SequentialSampler
    • RandomSampler
    • WeightedSampler
    • SubsetRandomSampler

    SubsetRandomSampler 常用来将数据集划分为训练集和测试集,比如这里就训练集和测试集按7:3 进行分割

    n_train = len(sentiment_train_set)
    split = n_train // 3
    
    indices
    = list(range(n_train)) train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]) valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
    train_loader
    = DataLoader(sentiment_train_set, sampler=train_sampler, shuffle=False) valid_loader = DataLoader(sentiment_train_set, sampler=valid_sampler, shuffle=False)

    具体推荐下面的博文,讲得挺详细的

    一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    https://www.cnblogs.com/marsggbo/p/11541054.html

    Pytorch Sampler详解

    https://www.cnblogs.com/marsggbo/p/11541054.html

    collate_fn

    可以用来进行一些数据处理,比如在文本任务中,一般由于文本长度不一致,我们需要进行截断或者填充。对于图片,我们则希望它们有同样的尺寸

    我么可以编写一个函数,然后用这个参数调用它,下面是一个简单的例子,我们把文本截断成只有10个字符

    def truncate(data_list):
      """传进一个batch_size大小的数据"""
      for data in data_list:
        text = data["text"]
        data["text"]=text[:10]
      return data_list
    
    test_loader = DataLoader(sentiment_train_set, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=truncate)

    我们可以看看返回的内容是否已经经过截断了

    for i in test_loader:
      print(i)
      break

    这时候返回的是一个列表而不是字典了,其中一个 batch 的返回结果如下,我们可以看到这里一个样本放在了一个字典中

    [{'text': '看了一个通宵,实在是', 'label': 1}, 。。。, {'text': '看了携程的其他用户评', 'label': 0}]

    下面是没有使用 collate_fn 的返回结果,它会将数据和标签分开,存放在一起,如下所示

    {

    'text':['3月1号订的,3月15号还没到货 客服每天说下个工作日能到货已经连续5天了 我无语。想早点儿看这本书的人还是去陶宝或卓越上订吧,尤其是广东省的朋友.当当送货太没保证了.',。。。, '非常纯朴的故事,但包含了主人公坎坷的一生,活着就是痛苦,不得不佩服生命的韧性'],

    'label': tensor([1, 。。。, 0])

    }

  • 相关阅读:
    editplus 支持lua语言语法高亮显示
    云服务器使用: 域名备案
    2-使用git管理一个单片机程序
    1-git的安装和基本使用
    编译lua固件NodeMcu 8266
    linux 安装Apache服务器
    2-STM32物联网开发WIFI(ESP8266)+GPRS(Air202)系统方案安全篇(监听Wi-Fi和APP的数据)
    Spring源码学习之:ClassLoader学习(3)
    Spring源码学习之:ClassLoader学习(2)
    Spring源码学习之:ClassLoader学习(1)
  • 原文地址:https://www.cnblogs.com/dogecheng/p/11930535.html
Copyright © 2011-2022 走看看