  • Redefining TensorDataset in PyTorch so that transforms can be applied to the data and targets

    import torch
    from torch.utils.data import Dataset


    class TensorsDataset(Dataset):
        '''
        A simple dataset that loads the tensors passed in as input. It behaves like
        torch.utils.data.TensorDataset, except that transformations can be applied
        to the data tensor and the target tensor. The target tensor may also be None,
        in which case only the data sample is returned.
        '''

        def __init__(self, data_tensor, target_tensor=None, transforms=None, target_transforms=None):
            # Data and targets must contain the same number of samples along dim 0
            if target_tensor is not None:
                assert data_tensor.size(0) == target_tensor.size(0), \
                    "data_tensor and target_tensor must have the same size along dim 0"
            self.data_tensor = data_tensor
            self.target_tensor = target_tensor

            # Normalize both transform arguments to lists of callables
            if transforms is None:
                transforms = []
            if target_transforms is None:
                target_transforms = []

            # A single callable is wrapped in a list; lists and tuples are kept as sequences
            if not isinstance(transforms, (list, tuple)):
                transforms = [transforms]
            if not isinstance(target_transforms, (list, tuple)):
                target_transforms = [target_transforms]

            self.transforms = list(transforms)
            self.target_transforms = list(target_transforms)

        def __getitem__(self, index):
            # Apply the data transforms, in order, to the selected sample
            data_tensor = self.data_tensor[index]
            for transform in self.transforms:
                data_tensor = transform(data_tensor)

            # Unlabeled dataset: return only the (transformed) data
            if self.target_tensor is None:
                return data_tensor

            # Apply the target transforms, in order, to the matching target
            target_tensor = self.target_tensor[index]
            for transform in self.target_transforms:
                target_tensor = transform(target_tensor)

            return data_tensor, target_tensor

        def __len__(self):
            # Number of samples = size of the first dimension
            return self.data_tensor.size(0)
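For comparison, the same behaviour can be sketched by wrapping the built-in torch.utils.data.TensorDataset instead of indexing the tensors directly, followed by a small usage example with DataLoader. This is a minimal sketch, not the original author's code: the class name TransformedTensorDataset, the toy tensors and the two lambda transforms are hypothetical and chosen only for illustration, and it assumes TensorDataset returns a tuple of tensors for each index.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class TransformedTensorDataset(Dataset):
    """Hypothetical wrapper: applies per-sample transforms on top of TensorDataset."""

    def __init__(self, data, target=None, transforms=(), target_transforms=()):
        # Delegate storage, indexing and the dim-0 size check to the built-in class
        tensors = (data,) if target is None else (data, target)
        self.base = TensorDataset(*tensors)
        self.transforms = list(transforms)
        self.target_transforms = list(target_transforms)

    def __getitem__(self, index):
        sample = self.base[index]  # a tuple with one or two tensors
        x = sample[0]
        for t in self.transforms:
            x = t(x)
        if len(sample) == 1:  # no targets were given
            return x
        y = sample[1]
        for t in self.target_transforms:
            y = t(y)
        return x, y

    def __len__(self):
        return len(self.base)


# Usage: 8 samples with 3 features each and integer class labels (toy data)
data = torch.arange(24, dtype=torch.float32).reshape(8, 3)
labels = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])

ds = TransformedTensorDataset(
    data,
    labels,
    transforms=[lambda x: x / 10.0],  # hypothetical scaling transform
    target_transforms=[lambda y: torch.nn.functional.one_hot(y, num_classes=3)],  # hypothetical
)

x0, y0 = ds[0]
loader = DataLoader(ds, batch_size=4, shuffle=False)
xb, yb = next(iter(loader))
print(x0, y0)
print(xb.shape, yb.shape)
```

Delegating to TensorDataset reuses its size check and indexing, at the cost of one extra level of indirection; the hand-written version above avoids that and keeps everything in one class.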
    
  • Original article: https://www.cnblogs.com/marsggbo/p/10459235.html