zoukankan      html  css  js  c++  java
  • PyTorch笔记之 Dataset 和 Dataloader

    简介

    在 PyTorch 中,我们的数据集往往会用一个类去表示,在训练时用 Dataloader 产生一个 batch 的数据

    https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

    比如官方例子中对 CIFAR10 图像数据集进行分类,就有用到这样的操作,具体代码如下所示

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)

    简单说,用 一个类 抽象地表示数据集,而 Dataloader 作为迭代器,每次产生一个 batch 大小的数据,节省内存

    Dataset

    Dataset 是 PyTorch 中用来表示数据集的一个抽象类,我们的数据集可以用这个类来表示,至少覆写下面两个方法即可

    这返回数据前可以进行适当的数据处理,比如将原文用一串数字序列表示

    • __len__:数据集大小
    • __getitem__:实现这个方法后,可以通过下标的方式( dataset[i] )的来取得第 $i$ 个数据

    下面我们来为编写一个类表示一个情感二分类数据集,继续用苏神整理的数据集

    https://github.com/bojone/bert4keras/tree/master/examples/datasets

    数据集没有表头,只有2列,一列是评论(文本),另一列是标签,以制表符进行分隔

    from torch.utils.data import Dataset, DataLoader
    import pandas as pd
    
    class SentimentDataset(Dataset):
        def __init__(self, path_to_file):
            self.dataset = pd.read_csv(path_to_file, sep="	", names=["text", "label"])
        def __len__(self):
            return len(self.dataset)
        def __getitem__(self, idx):
            text = self.dataset.loc[idx, "text"]
            label = self.dataset.loc[idx, "label"]
            sample = {"text": text, "label": label}
            return sample

    Dataloader

    基本使用

    Dataloader 就是一个迭代器,最基本的使用就是传入一个 Dataset 对象,它就会根据参数 batch_size 的值生成一个 batch 的数据

    if __name__ == "__main__":
        sentiment_dataset = SentimentDataset("sentiment.test.data")
        sentiment_dataloader = DataLoader(sentiment_dataset, batch_size=4, shuffle=True, num_workers=2)
        for idx, batch_samples in enumerate(sentiment_dataloader):
            text_batchs, text_labels = batch_samples["text"], batch_samples["label"]
            print(text_batchs)

    Sampler

    PyTorch 提供了 Sampler 模块,用来对数据进行采样,可以在 DataLoader 的通过 sampler 参数调用

    一般我们的加载训练集的 dataloader ,shuffle参数都会设置为True ,这时候使用了一个默认的采样器——RandomSampler

    当 shuffle 设置为 False 时,默认使用的是 SequencetialSampler,其实就是按顺序取出数据集中的元素

    在 PyTorch 中默认实现了以下 Sampler,如果我们要使用别的 Sampler, shuffle 要设置为 False

    • SequentialSampler
    • RandomSampler
    • WeightedSampler
    • SubsetRandomSampler

    SubsetRandomSampler 常用来将数据集划分为训练集和测试集,比如这里就训练集和测试集按7:3 进行分割

    n_train = len(sentiment_train_set)
    split = n_train // 3
    
    indices
    = list(range(n_train)) train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:]) valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
    train_loader
    = DataLoader(sentiment_train_set, sampler=train_sampler, shuffle=False) valid_loader = DataLoader(sentiment_train_set, sampler=valid_sampler, shuffle=False)

    具体推荐下面的博文,讲得挺详细的

    一文弄懂Pytorch的DataLoader, DataSet, Sampler之间的关系

    https://www.cnblogs.com/marsggbo/p/11541054.html

    Pytorch Sampler详解

    https://www.cnblogs.com/marsggbo/p/11541054.html

    collate_fn

    可以用来进行一些数据处理,比如在文本任务中,一般由于文本长度不一致,我们需要进行截断或者填充。对于图片,我们则希望它们有同样的尺寸

    我么可以编写一个函数,然后用这个参数调用它,下面是一个简单的例子,我们把文本截断成只有10个字符

    def truncate(data_list):
      """传进一个batch_size大小的数据"""
      for data in data_list:
        text = data["text"]
        data["text"]=text[:10]
      return data_list
    
    test_loader = DataLoader(sentiment_train_set, batch_size=batch_size, shuffle=True, num_workers=2, collate_fn=truncate)

    我们可以看看返回的内容是否已经经过截断了

    for i in test_loader:
      print(i)
      break

    这时候返回的是一个列表而不是字典了,其中一个 batch 的返回结果如下,我们可以看到这里一个样本放在了一个字典中

    [{'text': '看了一个通宵,实在是', 'label': 1}, 。。。, {'text': '看了携程的其他用户评', 'label': 0}]

    下面是没有使用 collate_fn 的返回结果,它会将数据和标签分开,存放在一起,如下所示

    {

    'text':['3月1号订的,3月15号还没到货 客服每天说下个工作日能到货已经连续5天了 我无语。想早点儿看这本书的人还是去陶宝或卓越上订吧,尤其是广东省的朋友.当当送货太没保证了.',。。。, '非常纯朴的故事,但包含了主人公坎坷的一生,活着就是痛苦,不得不佩服生命的韧性'],

    'label': tensor([1, 。。。, 0])

    }

  • 相关阅读:
    把git项目放到个人服务器上
    关于fcitx无法切换输入法的问题解决
    博客变迁通知
    (欧拉回路 并查集 别犯傻逼的错了) 7:欧拉回路 OpenJudge 数据结构与算法MOOC / 第七章 图 练习题(Excercise for chapter7 graphs)
    (并查集) HDU 1856 More is better
    (并查集 不太会) HDU 1272 小希的迷宫
    (并查集 注意别再犯傻逼的错了) HDU 1213 How Many Tables
    (最小生成树 Kruskal算法) 51nod 1212 无向图最小生成树
    (并查集) HDU 1232 畅通工程
    (最小生成树 Prim) HDU 1233 还是畅通工程
  • 原文地址:https://www.cnblogs.com/dogecheng/p/11930535.html
Copyright © 2011-2022 走看看