  • Redefining PyTorch's TensorDataset so that it supports transforms

    import torch


    class TensorsDataset(torch.utils.data.Dataset):
        '''
        A simple dataset that loads the tensors passed as input. It behaves like
        torch.utils.data.TensorDataset, except that transformations can be applied
        to the data tensor and to the target tensor. The target tensor may also be
        None, in which case only the data is returned.
        '''

        def __init__(self, data_tensor, target_tensor=None, transforms=None, target_transforms=None):
            # Data and targets must contain the same number of samples (dimension 0).
            if target_tensor is not None:
                assert data_tensor.size(0) == target_tensor.size(0), \
                    "data_tensor and target_tensor must have the same size in dimension 0"
            self.data_tensor = data_tensor
            self.target_tensor = target_tensor

            # Normalize both arguments to lists so __getitem__ can iterate over them;
            # a single callable or a tuple of callables is also accepted.
            if transforms is None:
                transforms = []
            if target_transforms is None:
                target_transforms = []

            if not isinstance(transforms, (list, tuple)):
                transforms = [transforms]
            if not isinstance(target_transforms, (list, tuple)):
                target_transforms = [target_transforms]

            self.transforms = list(transforms)
            self.target_transforms = list(target_transforms)

        def __getitem__(self, index):
            # Apply the data transforms in order; each one receives the previous output.
            data_tensor = self.data_tensor[index]
            for transform in self.transforms:
                data_tensor = transform(data_tensor)

            # Without targets, return only the (transformed) sample.
            if self.target_tensor is None:
                return data_tensor

            target_tensor = self.target_tensor[index]
            for transform in self.target_transforms:
                target_tensor = transform(target_tensor)

            return data_tensor, target_tensor

        def __len__(self):
            # Number of samples = size of the first dimension.
            return self.data_tensor.size(0)
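The `transforms` and `target_transforms` arguments accept any callable that maps a tensor to a tensor, either on its own or in a list. The block below is a minimal sketch, assuming (C, H, W) image tensors and scalar integer labels. `random_hflip`, `add_gaussian_noise`, `to_float` and the helper `apply_in_order` are hypothetical names chosen for illustration, not part of PyTorch; `apply_in_order` reproduces the loop in `__getitem__` so the composition order can be checked without building a dataset.

```python
import torch

# Hypothetical example transforms: any callable that maps a tensor to a
# tensor can be placed in the `transforms` / `target_transforms` lists.

def random_hflip(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Flip a (C, H, W) tensor along its last (width) axis with probability p."""
    if torch.rand(1).item() < p:
        return torch.flip(x, dims=[-1])
    return x


def add_gaussian_noise(x: torch.Tensor, std: float = 0.1) -> torch.Tensor:
    """Add zero-mean Gaussian noise with standard deviation `std`."""
    return x + torch.randn_like(x) * std


def to_float(y: torch.Tensor) -> torch.Tensor:
    """Cast a label tensor to float32, e.g. for a regression-style loss."""
    return y.float()


def apply_in_order(x: torch.Tensor, transforms) -> torch.Tensor:
    # Hypothetical helper mirroring the loop in TensorsDataset.__getitem__:
    # each transform receives the output of the previous one.
    for t in transforms:
        x = t(x)
    return x


if __name__ == "__main__":
    torch.manual_seed(0)
    image = torch.zeros(3, 4, 4)
    out = apply_in_order(image, [random_hflip, add_gaussian_noise])
    print(out.shape)  # torch.Size([3, 4, 4])
    label = torch.tensor(1)
    print(apply_in_order(label, [to_float]).dtype)  # torch.float32
```

With the class above, these could be passed as, for example, `TensorsDataset(images, labels, transforms=[random_hflip, add_gaussian_noise], target_transforms=to_float)`, where `images` and `labels` are hypothetical tensors; a single callable is wrapped into a list by `__init__`. Because the transforms run inside `__getitem__`, random ones such as `random_hflip` are re-sampled on every access, so each epoch sees a different augmentation of the same sample.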
    
  • Original post: https://www.cnblogs.com/marsggbo/p/10459235.html