  • PyTorch Dataset and DataLoader usage (an example)

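    from torch.utils.data import Dataset
    from torch.utils.data import DataLoader
    import numpy as np
    import torch
    
    class OwnDataset(Dataset):
        """Map-style dataset: pairs each feature row with its label."""
        def __init__(self, x, y):
            self.x = x
            self.y = y
        
        def __getitem__(self, index):
            # Return one (sample, label) pair; DataLoader collates these into batches
            return self.x[index], self.y[index]
        
        def __len__(self):
            # Number of samples; DataLoader uses it to generate indices
            return len(self.x)
    
    x_train = np.random.rand(100, 2)             # 100 samples, 2 features each (float64)
    y_train = np.random.randint(2, size=100)     # 100 binary labels (0 or 1)
    
    # Convert to tensors; models usually expect float32 inputs and int64 class labels
    x_train = torch.from_numpy(x_train).float()
    y_train = torch.from_numpy(y_train).long()
    
    dataset = OwnDataset(x_train, y_train)
    train_loader = DataLoader(dataset, batch_size=40, shuffle=True)
    
    for epoch in range(2):
        # 100 samples with batch_size=40 -> 3 batches per epoch (40, 40, 20)
        for i, data in enumerate(train_loader):
            inputs, labels = data  # next, feed these into the model
            print("i:", i)
            print("inputs:", inputs)
            print("len(inputs):", len(inputs))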

     For more usage details, see the official documentation:

    https://pytorch.org/docs/stable/data.html?highlight=dataset#torch.utils.data.Dataset
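
For simple paired tensors, PyTorch's built-in `torch.utils.data.TensorDataset` does the same job as a hand-written dataset like `OwnDataset`: it indexes every tensor it holds along the first dimension. The sketch below uses `TensorDataset` with random data and shows how `drop_last` and a sampler change the batches a `DataLoader` produces. The 80-sample training subset is an arbitrary assumption for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import sampler

torch.manual_seed(0)

x = torch.rand(100, 2)           # float32 features
y = torch.randint(0, 2, (100,))  # int64 labels

# TensorDataset indexes all tensors along dim 0
dataset = TensorDataset(x, y)

# Default: the last batch keeps the 20 leftover samples
sizes = [inputs.size(0) for inputs, _ in DataLoader(dataset, batch_size=40, shuffle=True)]
print(sizes)  # [40, 40, 20]

# drop_last=True discards the incomplete final batch
sizes_dropped = [
    inputs.size(0)
    for inputs, _ in DataLoader(dataset, batch_size=40, shuffle=True, drop_last=True)
]
print(sizes_dropped)  # [40, 40]

# A sampler replaces shuffle=True (passing both raises ValueError).
# Assumed split for illustration: draw only the first 80 indices, in random order.
train_sampler = sampler.SubsetRandomSampler(list(range(80)))
subset_loader = DataLoader(dataset, batch_size=40, sampler=train_sampler)
subset_sizes = [inputs.size(0) for inputs, _ in subset_loader]
print(subset_sizes)  # [40, 40]
```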

  • Original post: https://www.cnblogs.com/qiezi-online/p/14098960.html