  • PyTorch Dataset and DataLoader usage (an example)

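    from torch.utils.data import Dataset
    from torch.utils.data import DataLoader
    import numpy as np
    import torch
    
    class OwnDataset(Dataset):
        """Wraps paired feature/label tensors as a map-style Dataset."""
        def __init__(self, x, y):
            self.x = x
            self.y = y
        
        def __getitem__(self, index):
            # Return one (sample, label) pair; DataLoader collates these into batches
            return self.x[index], self.y[index]
        
        def __len__(self):
            # Number of samples; DataLoader uses this to generate indices
            return len(self.x)
    
    x_train = np.random.rand(100, 2)          # 100 samples, 2 features each (float64)
    y_train = np.random.randint(2, size=100)  # 100 binary labels (0 or 1)
    
    # Convert to tensors: .float() casts features to float32 (the default dtype of
    # most PyTorch layers), .long() makes labels int64 as losses like CrossEntropyLoss expect
    x_train = torch.from_numpy(x_train).float()
    y_train = torch.from_numpy(y_train).long()
    
    dataset = OwnDataset(x_train, y_train)
    train_loader = DataLoader(dataset, batch_size=40, shuffle=True)
    
    for epoch in range(2):
        for i, data in enumerate(train_loader):
            inputs, labels = data  # next, feed these into the model
            print("i:", i)
            print("inputs:", inputs)
            print("len(inputs):", len(inputs))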
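With 100 samples and `batch_size=40`, each epoch yields ceil(100 / 40) = 3 batches of sizes 40, 40 and 20, which is why the last `len(inputs)` printed per epoch is smaller. Passing `drop_last=True` discards that incomplete final batch. The sketch below checks this; as an assumption for brevity it uses the built-in `TensorDataset` in place of `OwnDataset` (both simply index paired tensors along the first dimension), and `batch_sizes` is a hypothetical helper.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Same shapes as the example: 100 samples, 2 float features, binary labels
x = torch.rand(100, 2)
y = torch.randint(0, 2, (100,))

# TensorDataset returns (x[i], y[i]) for index i, like OwnDataset
dataset = TensorDataset(x, y)

def batch_sizes(loader):
    """Hypothetical helper: number of samples in each batch of one epoch."""
    return [inputs.shape[0] for inputs, labels in loader]

keep_last = DataLoader(dataset, batch_size=40, shuffle=True)
drop_last = DataLoader(dataset, batch_size=40, shuffle=True, drop_last=True)

print(batch_sizes(keep_last))          # [40, 40, 20]
print(batch_sizes(drop_last))          # [40, 40]
print(len(keep_last), len(drop_last))  # 3 2
```

`len(loader)` reports the number of batches per epoch, not the number of samples, so it is a convenient way to size progress bars or learning-rate schedules.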

 For more details, see the official documentation:

    https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset
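The example imports `sampler` but never uses it. A sampler decides which dataset indices the DataLoader draws and in what order. A minimal sketch, assuming a hypothetical 80/20 train/validation split of the same kind of random data, uses `sampler.SubsetRandomSampler`; the names `train_idx` and `val_idx` are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import sampler

x = torch.rand(100, 2)
y = torch.randint(0, 2, (100,))
dataset = TensorDataset(x, y)

# Hypothetical 80/20 split of the sample indices into disjoint subsets
indices = torch.randperm(len(dataset)).tolist()
train_idx, val_idx = indices[:80], indices[80:]

# Each SubsetRandomSampler yields only its own indices, reshuffled every epoch.
# DataLoader rejects shuffle=True together with a sampler, so shuffle is left at its default (False).
train_loader = DataLoader(dataset, batch_size=40, sampler=sampler.SubsetRandomSampler(train_idx))
val_loader = DataLoader(dataset, batch_size=40, sampler=sampler.SubsetRandomSampler(val_idx))

train_sizes = [len(inputs) for inputs, _ in train_loader]
val_sizes = [len(inputs) for inputs, _ in val_loader]
print(train_sizes, val_sizes)  # [40, 40] [20]
```

Because both loaders wrap the same `dataset`, no data is copied; only the index streams differ.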

  • Original article: https://www.cnblogs.com/qiezi-online/p/14098960.html