zoukankan      html  css  js  c++  java
  • Dataset和Dataloader

    torch.utils.data.Dataset与torch.utils.data.DataLoader的理解

    1. pytorch提供了一个数据读取的方法,其由两个类构成:torch.utils.data.Dataset和DataLoader
    2. 我们要自定义自己数据读取的方法,就需要继承torch.utils.data.Dataset,并将其封装到DataLoader中
    3. torch.utils.data.Dataset表示该数据集,继承该类可以重载其中的方法,实现多种数据读取及数据预处理方式
    4. torch.utils.data.DataLoader 封装了Data对象,实现单(多)进程迭代器输出数据集

    一、定义自己的Dataset (torch.utils.data.Dataset)

    1. 要自定义自己的Dataset类,至少要重载两个方法,__len__, __getitem__
    2. __len__返回的是数据集的大小
    3. __getitem__实现索引数据集中的某一个数据
    4. 除了这两个基本功能,还可以在__getitem__时对数据进行预处理,或者是直接在硬盘中读取数据,对于超大的数据集还可以使用lmdb来读取
    from torch.utils.data import DataLoader, Dataset
    import torch
    
    class MyDataset(Dataset):
        # TensorDataset继承Dataset, 重载了__init__, __getitem__, __len__
        # 实现将一组Tensor数据对封装成Tensor数据集
        # 能够通过index得到数据集的数据,能够通过len,得到数据集大小
    
        def __init__(self, data_tensor, target_tensor):
            self.data_tensor = data_tensor
            self.target_tensor = target_tensor
    
        def __getitem__(self, index):
            return self.data_tensor[index], self.target_tensor[index]
    
        def __len__(self):
            return self.data_tensor.size(0)
    
    # 生成数据
    data_tensor = torch.randn(4, 3)
    target_tensor = torch.rand(4)
    print('x:',data_tensor)
    print('y:',target_tensor)
    # 将数据封装成Dataset
    tensor_dataset = MyDataset(data_tensor, target_tensor)
    
    # 可使用索引调用数据
    print ('tensor_data[0]: ', tensor_dataset[0])
    print( 'len os tensor_dataset: ', len(tensor_dataset))

    输出:

    x: tensor([[ 1.2816,  0.8122,  0.1183],
            [ 1.2182, -0.1133,  0.5438],
            [-0.3239, -0.4611,  0.7439],
            [-0.0841, -0.7142, -0.1525]])
    y: tensor([0.7254, 0.3795, 0.0325, 0.2877])
    tensor_data[0]:  (tensor([1.2816, 0.8122, 0.1183]), tensor(0.7254))
    len os tensor_dataset:  4

    基于MovieLens数据集的定义
    class MovieLens20MDataset(torch.utils.data.Dataset):
        def __init__(self, dataset_path, sep=',', engine='c', header='infer'):
            data = pd.read_csv(dataset_path, sep=sep, engine=engine, header=None).to_numpy()[:, :3]
            self.items = data[:, :2].astype(np.int) - 1  # -1 because ID begins from 1
            self.targets = self.__preprocess_target(data[:, 2]).astype(np.float32)
            self.field_dims = np.max(self.items, axis=0) + 1
            print(self.field_dims)
            self.user_field_idx = np.array((0, ), dtype=np.long)
            self.item_field_idx = np.array((1,), dtype=np.long)
    
        def __len__(self):
            return self.targets.shape[0]
    
        def __getitem__(self, index):
            return self.items[index], self.targets[index]
    
        def __preprocess_target(self, target):
            target[target <= 3] = 0
            target[target > 3] = 1
            return target
    
    class MovieLens1MDataset(MovieLens20MDataset):
        def __init__(self, dataset_path):
            super().__init__(dataset_path, sep=',', engine='python', header=None)

    二、Dataloader使用 (torch.utils.data.Dataloader)

    1. Dataloader将Dataset或其子类封装成一个迭代器
    2. 这个迭代器可以迭代输出Dataset的内容
    3. 同时可以实现多进程、shuffle、不同采样策略,数据校对等等处理过程
    tensor_dataloader = DataLoader(tensor_dataset,   # 封装的对象
                                   batch_size=2,     # 输出的batchsize
                                   shuffle=True,     # 随机输出
                                   num_workers=0)    # 只有1个进程
    
    # 以for循环形式输出
    for data, target in tensor_dataloader: 
        print(data, target)
    print('----------------------------------------')
    # 输出一个batch
    print ('one batch tensor data: ', iter(tensor_dataloader).next())
    # 输出batch数量
    print ('len of batchtensor: ', len(list(iter(tensor_dataloader))))

    输出:

    tensor([[-0.3239, -0.4611,  0.7439],
            [ 1.2182, -0.1133,  0.5438]]) tensor([0.0325, 0.3795])
    tensor([[-0.0841, -0.7142, -0.1525],
            [ 1.2816,  0.8122,  0.1183]]) tensor([0.2877, 0.7254])
    ----------------------------------------
    one batch tensor data:  [tensor([[-0.3239, -0.4611,  0.7439],
            [ 1.2816,  0.8122,  0.1183]]), tensor([0.0325, 0.7254])]
    len of batchtensor:  2
     
    
    
  • 相关阅读:
    翻译MDN里js的一些方法属性
    ajax相关
    我的面试错题
    写代码通用思路
    工厂模式
    cookie & session
    X-UA-Compatible设置IE浏览器兼容模式
    [转]IE6/IE7/IE8/IE9中tbody的innerHTML不能赋值的完美解决方案
    EasyUseCase 一款脑图转化 Excel 测试用例工具 (1.2 版本升级)
    XMind2TestCase:一个高效测试用例设计的解决方案!
  • 原文地址:https://www.cnblogs.com/gczr/p/14351737.html
Copyright © 2011-2022 走看看