zoukankan      html  css  js  c++  java
  • 深度网络学习-PyTorch_自定义Dataset

    PyTorch中的数据

    Dataset DataLoader transforms
    数据集的格式
    

    分类生成标签

     制作训练和验证数据的.txt文件
     
    #!/usr/bin/env python3
    # -*- coding: UTF-8 -*-
    
    import os 
    
    def list_dir(path):
        """Map every category sub-directory of ``path`` to its labelled file paths.

        Only the immediate sub-directories of ``path`` are treated as categories;
        plain files that sit directly under ``path`` are ignored.  Each entry
        found inside a category directory is recorded as ``/<category>/<entry>``
        (joined with ``os.path.join``), i.e. a path relative to ``path`` that
        starts with a separator.

        Args:
            path: Root data directory that contains one folder per category.

        Returns:
            dict mapping the category folder name to the list of
            ``/<category>/<entry>`` strings, with both the categories and the
            entries in sorted order.
        """
        res = {}
        # os.listdir() yields entries in arbitrary order; sorting makes the
        # result (and any label file generated from it) reproducible.
        for category in sorted(os.listdir(path)):
            temp_dir = os.path.join(path, category)
            if os.path.isdir(temp_dir):
                res[category] = [
                    os.path.join("/", category, data)
                    for data in sorted(os.listdir(temp_dir))
                ]
        return res
    
    def get_text(path, fil_dict, relation=None):
        """Write a ``<file path><TAB><label>`` index file into ``path``.

        The output file is named after the last component of ``path`` with a
        ``.txt`` suffix (``.../train`` -> ``.../train/train.txt``) and is written
        as UTF-8.  Every written line is also echoed to stdout.

        Args:
            path: Directory that receives the index file.
            fil_dict: Mapping of category name to an iterable of file paths.
            relation: Optional mapping of category name to its integer label.
                Defaults to ``{"dog": 1, "cat": 2}``.

        Raises:
            KeyError: If a category that has files is missing from ``relation``.
        """
        if relation is None:
            relation = {"dog": 1, "cat": 2}
        # normpath drops a trailing separator so "data/train/" still produces
        # "train.txt" rather than a bare ".txt".
        file_nm = os.path.basename(os.path.normpath(path)) + ".txt"
        with open(os.path.join(path, file_nm), mode="w", encoding="utf-8") as f:
            for category_key, label_files in fil_dict.items():
                for label_file in label_files:
                    # One sample per line: path, a tab, then the numeric label.
                    labe_res = label_file + "\t" + str(relation[category_key])
                    print(labe_res)
                    f.write(labe_res + "\n")
    
    
    if __name__ == '__main__':
        # Scan the training directory per category, then write its label index.
        train_root = "./pytorch/data/train"
        get_text(train_root, list_dir(train_root))
    

    自定义Dataset

     自定义Dataset,继承Dataset, 重写抽象方法:__init__, __len__, __getitem__
    
    #!/usr/bin/env python3
    # -*- coding: UTF-8 -*-
    
    import os
    import cv2
    import numpy as np
    import torch
    from torch.utils.data import Dataset, DataLoader
    from torchvision.transforms import transforms
    
    
    # step1: 定义MyDataset类, 继承Dataset, 重写抽象方法:__init__, __len__, __getitem__
    class MyDataset(Dataset):
        """Image dataset driven by a text index of ``<path><TAB><label>`` lines.

        Each non-blank line of ``names_file`` names one sample: a path that is
        appended to ``root_dir`` to locate the image, a tab, and an integer
        label.  Samples are returned as ``{'image': ..., 'label': ...}`` dicts.
        """

        def __init__(self, root_dir, names_file, transform=None):
            """Load the sample index.

            Args:
                root_dir: Prefix that is string-concatenated with each indexed
                    path to form the image location.
                names_file: Path of the UTF-8 index file.
                transform: Optional callable applied to every sample dict.

            Raises:
                FileNotFoundError: If ``names_file`` is not an existing file.
            """
            self.root_dir = root_dir
            self.names_file = names_file
            self.transform = transform
            self.size = 0
            self.names_list = []

            if not os.path.isfile(self.names_file):
                # Fail with a clear message instead of printing and then
                # crashing inside open().
                raise FileNotFoundError(self.names_file + ' ## does not exist!')
            # The context manager closes the handle; blank lines are skipped so
            # they are never counted as samples.
            with open(self.names_file, encoding="utf-8") as file:
                self.names_list = [line for line in file if line.strip()]
            self.size = len(self.names_list)

        def __len__(self):
            """Return the number of indexed samples."""
            return self.size

        def __getitem__(self, idx):
            """Return sample ``idx`` as a dict, or ``None`` if its image is missing.

            Returns:
                ``{'image': <cv2.imread result>, 'label': int}`` passed through
                ``transform`` when one was given; ``None`` (after printing a
                message) when the image file does not exist.

            Raises:
                IndexError: If ``idx`` is outside the index, which also ends
                    plain ``for`` iteration over the dataset.
            """
            fields = self.names_list[idx].split('\t')
            image_path = self.root_dir + fields[0]
            if not os.path.isfile(image_path):
                print(image_path + '@does not exist!')
                return None
            # NOTE(review): cv2.imread reportedly yields None for unreadable
            # files; that case is not handled here -- confirm against callers.
            image = cv2.imread(image_path)

            # int() tolerates the trailing newline kept on each index line.
            label = int(fields[1])

            sample = {'image': image, 'label': label}
            if self.transform:
                sample = self.transform(sample)
            return sample
    
    # # 变换Resize
    class Resize(object):
        """Sample transform that rescales the image with ``cv2.resize``."""

        def __init__(self, output_size: tuple):
            """Remember the target size handed to ``cv2.resize``."""
            self.output_size = output_size

        def __call__(self, sample):
            """Return a new sample dict holding the resized image and the same label."""
            resized = cv2.resize(sample['image'], self.output_size)
            return dict(image=resized, label=sample['label'])
    
    # # 变换ToTensor
    class ToTensor(object):
        """Sample transform that turns the image array into a torch tensor."""

        def __call__(self, sample):
            """Move the image's last axis to the front and wrap it as a tensor.

            The label is passed through unchanged.
            """
            channels_first = np.transpose(sample['image'], axes=(2, 0, 1))
            converted = {'image': torch.from_numpy(channels_first)}
            converted['label'] = sample['label']
            return converted
    
    
    if __name__ == "__main__":
        # Build the transform chain first: resize to 224x224, then convert to a tensor.
        pipeline = transforms.Compose([Resize((224, 224)), ToTensor()])
        train_dataset = MyDataset(
            root_dir='./pytorch/data/train',
            names_file='./pytorch/data/train/train.txt',
            transform=pipeline,
        )

        # Walk the dataset one sample at a time and print each label.
        for sample in train_dataset:
            image, label = sample['image'], sample['label']
            print(label)

        # Then iterate again in shuffled batches of four using worker processes.
        trainset_dataloader = DataLoader(
            train_dataset,
            batch_size=4,
            shuffle=True,
            num_workers=4,
        )

        for batch in trainset_dataloader:
            images_batch = batch['image']
            labels_batch = batch['label']
            print(labels_batch.shape, labels_batch.dtype)
            print(images_batch.shape, images_batch.dtype)
            print(labels_batch)
            print(images_batch)
    

    参考

         https://pytorch.org/docs/stable/data.html
  • 相关阅读:
    loj2042 「CQOI2016」不同的最小割
    loj2035 「SDOI2016」征途
    luogu2120 [ZJOI2007]仓库建设
    luogu3195 [HNOI2008]玩具装箱TOY
    51nod 1069 Nim游戏 + BZOJ 1022: [SHOI2008]小约翰的游戏John(Nim游戏和Anti-Nim游戏)
    HDU 5723 Abandoned country(最小生成树+边两边点数)
    BZOJ 1497: [NOI2006]最大获利(最大权闭合图)
    51nod 1615 跳跃的杰克
    SPOJ 839 Optimal Marks(最小割的应用)
    UVa 11107 生命的形式(不小于k个字符串中的最长子串)
  • 原文地址:https://www.cnblogs.com/ytwang/p/15239433.html
Copyright © 2011-2022 走看看