from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import numpy as np
import torch


class OwnDataset(Dataset):
    """Map-style dataset pairing samples ``x`` with labels ``y`` by index.

    Any pair of indexable containers of equal length works (NumPy arrays,
    torch tensors, plain lists); ``__getitem__`` simply returns
    ``(x[i], y[i])``.
    """

    def __init__(self, x, y):
        """Store the sample and label containers.

        Args:
            x: Indexable collection of samples.
            y: Indexable collection of labels, same length as ``x``.

        Raises:
            ValueError: If ``x`` and ``y`` differ in length.
        """
        # Fail early: __len__ is based on x, so a shorter y would otherwise
        # only surface as an IndexError partway through an epoch.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)


if __name__ == "__main__":
    # Toy data: 100 samples with 2 random features each, plus a 0/1 label.
    x_train = torch.from_numpy(np.random.rand(100, 2))
    y_train = torch.from_numpy(np.random.randint(2, size=100))

    dataset = OwnDataset(x_train, y_train)
    # 100 samples with batch_size=40 (drop_last defaults to False)
    # -> batches of 40, 40 and 20 per epoch.
    train_loader = DataLoader(dataset, batch_size=40, shuffle=True)

    for epoch in range(2):
        for i, (inputs, labels) in enumerate(train_loader):
            # Next, feed inputs/labels into the model.
            print("i:", i)
            print("inputs:", inputs)
            print("len(inputs):", len(inputs))
更多用法请参考 PyTorch 官方文档（torch.utils.data）：
https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset