zoukankan      html  css  js  c++  java
  • Pytorch划分数据集的方法:torch.utils.data.Subset

    torch.utils.data

     

    Pytorch提供的对数据集进行操作的函数详见:https://pytorch.org/docs/master/data.html#torch.utils.data.SubsetRandomSampler

    torch的这个文件包含了一些关于数据集处理的类:

    • class torch.utils.data.Dataset: 一个抽象类, 所有其他类的数据集类都应该是它的子类。而且其子类必须重载两个重要的函数:len(提供数据集的大小)、getitem(支持整数索引)。
    • class torch.utils.data.TensorDataset: 封装成tensor的数据集,每一个样本都通过索引张量来获得。
    • class torch.utils.data.ConcatDataset: 连接不同的数据集以构成更大的新数据集。
    • class torch.utils.data.Subset(dataset, indices): 获取指定一个索引序列对应的子数据集。
    • class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None): 数据加载器。组合了一个数据集和采样器,并提供关于数据的迭代器。
    • torch.utils.data.random_split(dataset, lengths): 按照给定的长度将数据集划分成没有重叠的新数据集组合。
    • class torch.utils.data.Sampler(data_source):所有采样的器的基类。每个采样器子类都需要提供 iter 方-法以方便迭代器进行索引 和一个 len方法 以方便返回迭代器的长度。
    • class torch.utils.data.SequentialSampler(data_source):顺序采样样本,始终按照同一个顺序。
    • class torch.utils.data.RandomSampler(data_source):无放回地随机采样样本元素。
    • class torch.utils.data.SubsetRandomSampler(indices):无放回地按照给定的索引列表采样样本元素。
    • class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照给定的概率来采样样本。
    • class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一个batch中封装一个其他的采样器。
    • class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采样器可以约束数据加载进数据集的子集。

    示例


    下面Pytorch提供的划分数据集的方法以示例的方式给出:

    SubsetRandomSampler

    
    
    dataset = MyCustomDataset(my_path)
    batch_size = 16
    validation_split = .2
    shuffle_dataset = True
    random_seed= 42
    
    # Creating data indices for training and validation splits:
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    if shuffle_dataset :
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    
    # Creating PT data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)
    
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
                                               sampler=train_sampler)
    validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                    sampler=valid_sampler)
    
    # Usage Example:
    num_epochs = 10
    for epoch in range(num_epochs):
        # Train:   
        for batch_index, (faces, labels) in enumerate(train_loader):
          
     

    random_split

    train_size = int(0.8 * len(full_dataset))
    test_size = len(full_dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
     

    参考:

    https://www.cnblogs.com/marsggbo/p/10496696.html

    https://stackoverflow.com/questions/50544730/how-do-i-split-a-custom-dataset-into-training-and-test-datasets

    https://likewind.top/2019/02/01/Pytorch-dataprocess/

    https://blog.csdn.net/xholes/article/details/81410834

  • 相关阅读:
    LeetCode 842. Split Array into Fibonacci Sequence
    LeetCode 1087. Brace Expansion
    LeetCode 1219. Path with Maximum Gold
    LeetCode 1079. Letter Tile Possibilities
    LeetCode 1049. Last Stone Weight II
    LeetCode 1046. Last Stone Weight
    LeetCode 1139. Largest 1-Bordered Square
    LeetCode 764. Largest Plus Sign
    LeetCode 1105. Filling Bookcase Shelves
    LeetCode 1027. Longest Arithmetic Sequence
  • 原文地址:https://www.cnblogs.com/Bella2017/p/11791216.html
Copyright © 2011-2022 走看看