  • Redefining TensorDataset in PyTorch to support transforms

    import torch
    
    
    class TensorsDataset(torch.utils.data.Dataset):
        '''
        A simple dataset that serves the tensors passed in as input. It behaves like
        torch.utils.data.TensorDataset, except that transforms can be applied to the data and target tensors.
        The target tensor may also be None, in which case only the data tensor is returned.
        '''
    
        def __init__(self, data_tensor, target_tensor=None, transforms=None, target_transforms=None):
            if target_tensor is not None:
                assert data_tensor.size(0) == target_tensor.size(0)
            self.data_tensor = data_tensor
            self.target_tensor = target_tensor
    
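            # No transforms given: fall back to an empty list (identity).
            if transforms is None:
                transforms = []
            if target_transforms is None:
                target_transforms = []
    
            # A single callable is wrapped in a list so __getitem__ can always iterate.
            if not isinstance(transforms, list):
                transforms = [transforms]
            if not isinstance(target_transforms, list):
                target_transforms = [target_transforms]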
    
            self.transforms = transforms
            self.target_transforms = target_transforms
    
        def __getitem__(self, index):
    
            data_tensor = self.data_tensor[index]
            for transform in self.transforms:
                data_tensor = transform(data_tensor)
    
            if self.target_tensor is None:
                return data_tensor
    
            target_tensor = self.target_tensor[index]
            for transform in self.target_transforms:
                target_tensor = transform(target_tensor)
    
            return data_tensor, target_tensor
    
        def __len__(self):
            return self.data_tensor.size(0)
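
A minimal usage sketch, not part of the original post. It restates the class in condensed form so the snippet runs on its own, then wraps the dataset in a `DataLoader`. The toy tensors and the `scale`, `add_one`, and `to_float` transforms are hypothetical names chosen for illustration; any callable that takes a tensor and returns a tensor should work in their place.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TensorsDataset(Dataset):
    """Condensed restatement of the class above so this snippet is self-contained."""

    def __init__(self, data_tensor, target_tensor=None, transforms=None, target_transforms=None):
        if target_tensor is not None:
            assert data_tensor.size(0) == target_tensor.size(0)
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
        self.transforms = self._as_list(transforms)
        self.target_transforms = self._as_list(target_transforms)

    @staticmethod
    def _as_list(t):
        # Accept None, a single callable, or a list of callables.
        if t is None:
            return []
        return t if isinstance(t, list) else [t]

    def __getitem__(self, index):
        data = self.data_tensor[index]
        for transform in self.transforms:
            data = transform(data)
        if self.target_tensor is None:
            return data
        target = self.target_tensor[index]
        for transform in self.target_transforms:
            target = transform(target)
        return data, target

    def __len__(self):
        return self.data_tensor.size(0)


# Hypothetical toy data: 8 samples with 3 features each, integer labels 0..7.
data = torch.arange(24, dtype=torch.float32).reshape(8, 3)
targets = torch.arange(8)


# Hypothetical transforms; each one is applied to a single sample, in list order.
def scale(x):
    return x / 10.0


def add_one(x):
    return x + 1.0


def to_float(y):
    return y.float()


# A list of transforms for the data, a single callable for the targets.
dataset = TensorsDataset(data, targets, transforms=[scale, add_one], target_transforms=to_float)
x0, y0 = dataset[0]

# The default collate function stacks the per-sample tensors into batches.
loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch_x, batch_y = next(iter(loader))

# Without a target tensor, indexing returns only the transformed data.
unlabeled = TensorsDataset(data, transforms=scale)
u0 = unlabeled[0]
```

Because the transforms run inside `__getitem__`, they are applied one sample at a time and re-evaluated on every access, which is what makes random augmentations possible but also means expensive deterministic transforms are recomputed each epoch.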
    
  • Original post: https://www.cnblogs.com/marsggbo/p/10459235.html