zoukankan      html  css  js  c++  java
  • Pytorch-创建图片的dataset和dataloader和数据过采样

    数据集的格式如下:

    datasets----train文件夹(WA和WKY文件夹,里面分别存放了200张图片)

                 ----test文件夹(WA和WKY文件夹,里面分别存放了100张图片)

    每一张图片都有自己的文件名,train中WA的图片标签为0,WKY的图片标签为1。

    1.构建Dataset

     1 import os
     2 import random
     3 import torch
     4 from torch.utils.data import Dataset
     5 import torchvision 
     6 import imghdr          
     7 from PIL import Image  
     8 import matplotlib.pyplot as plt
     9  
    10  
    class MedicalDataset(Dataset):
        """Two-class (WA / WKY) image dataset read from ``root/split/<class>/``.

        The label of a sample is the index of its class in ``CLASSES``
        (WA -> 0, WKY -> 1).

        Args:
            root: Directory that contains the split folders.
            split: Name of the split folder. ``'train'`` enables random
                crop/flip augmentation; any other value uses a deterministic
                centre crop.
            data_ratio: Fraction of the samples to keep. Values below 1.0
                select a random subset; values of 1.0 or more keep everything.

        Raises:
            ValueError: If ``data_ratio`` is negative.
        """

        CLASSES = ('WA', 'WKY')
        # Channel statistics commonly used with ImageNet-pretrained models.
        _MEAN = [0.485, 0.456, 0.406]
        _STD = [0.229, 0.224, 0.225]

        def __init__(self, root, split, data_ratio=1.0):
            super().__init__()
            if data_ratio < 0:
                raise ValueError(f'data_ratio must be non-negative, got {data_ratio!r}')

            self.img_list = []   # path of every sample file (root-relative if root is)
            self.cls_list = []   # label index of every sample, parallel to img_list
            self.cls_num = {}    # class name -> number of samples of that class

            for idx, cls in enumerate(self.CLASSES):
                cls_dir = os.path.join(root, split, cls)
                # Sorted so the sample order does not depend on the file system.
                for img_fp in sorted(os.listdir(cls_dir)):
                    self.img_list.append(os.path.join(cls_dir, img_fp))
                    self.cls_list.append(idx)
            self._recount()

            if data_ratio < 1.0:
                self._subsample(data_ratio)

            transforms = torchvision.transforms
            if split == 'train':
                crop_steps = [transforms.RandomCrop(224),
                              transforms.RandomHorizontalFlip()]
            else:
                crop_steps = [transforms.CenterCrop(224)]
            self.trans = transforms.Compose([
                transforms.Resize(256),
                *crop_steps,
                transforms.ToTensor(),
                transforms.Normalize(self._MEAN, self._STD),
            ])

        def _recount(self):
            """Recompute ``cls_num`` from the labels that are currently kept."""
            self.cls_num = {cls: self.cls_list.count(idx)
                            for idx, cls in enumerate(self.CLASSES)}

        def _subsample(self, data_ratio):
            """Keep a random ``data_ratio`` fraction of the samples, in place."""
            shuffled_idxs = list(range(len(self.img_list)))
            random.shuffle(shuffled_idxs)
            num_samples = round(data_ratio * len(self.img_list))
            kept = shuffled_idxs[:num_samples]
            self.img_list = [self.img_list[i] for i in kept]
            self.cls_list = [self.cls_list[i] for i in kept]
            # Without this the per-class counts would still describe the full set.
            self._recount()

        def _getdata(self):
            """Return the list of sample paths."""
            return self.img_list

        def __getitem__(self, index):
            """Return ``(image_tensor, label)`` for the sample at ``index``."""
            name = self.img_list[index]
            # Normalize expects exactly three channels, so grayscale / RGBA
            # files are converted; the context manager closes the file handle.
            with Image.open(name) as img:
                img = img.convert('RGB')
            # Both values must be returned so the DataLoader collates
            # (images, labels) batches.
            return self.trans(img), self.cls_list[index]

        def __len__(self):
            """Return the number of samples."""
            return len(self.img_list)

    如果想查看数据集中的样本并显示图片,可以使用下面的代码:

     1 from torch.utils.data import DataLoader
     2 
     # Build the training dataset and inspect it.
     dataset = MedicalDataset('datasets/', 'train')
     print('dataset: ', dataset)
     print('len= ', len(dataset))                   # total number of training samples: 400

     # Index -1 is the last sample in the list.
     img, label = dataset[-1]
     print('img.shape= ', img.shape)                # torch.Size([3, 224, 224])
     print('label= ', label)                        # 1

     # Each batch produced by the loader is the collated output of __getitem__.
     loader = DataLoader(dataset, batch_size=16, shuffle=True)
     # Fetch ONE batch and unpack it: calling next(iter(loader)) twice builds two
     # iterators and loads two different shuffled batches.
     images, labels = next(iter(loader))
     print(images.shape, labels.shape)              # torch.Size([16, 3, 224, 224]) torch.Size([16])
    13 
    # Display a single image
    unloader = torchvision.transforms.ToPILImage()  # converts a tensor or ndarray to a PIL image

    def imshow(tensor, title=None, mean=None, std=None):
        """Show an image tensor with matplotlib.

        Args:
            tensor: Image tensor; a leading batch dimension of size 1 is removed.
            title: Optional figure title.
            mean: Optional per-channel mean that was used to normalise the tensor.
            std: Optional per-channel std that was used to normalise the tensor.
                When both ``mean`` and ``std`` are given the normalisation is
                undone before display, otherwise the colours look wrong.
        """
        image = tensor.detach().cpu().clone()   # clone so the caller's tensor is not modified
        image = image.squeeze(0)

        if mean is not None and std is not None:
            mean_t = image.new_tensor(mean).view(-1, 1, 1)
            std_t = image.new_tensor(std).view(-1, 1, 1)
            image = image * std_t + mean_t
        if image.is_floating_point():
            # ToPILImage scales floats by 255 and casts to uint8, so values
            # outside [0, 1] would wrap around instead of saturating.
            image = image.clamp(0.0, 1.0)

        image = unloader(image)         # tensor -> PIL image
        plt.imshow(image)
        if title is not None:
            plt.title(title)
        plt.pause(1)                    # keep the window up briefly

    plt.figure()
    # img was normalised by the dataset's transform, so pass the same statistics back.
    imshow(img, title='Image', mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    2.创建DataLoader

    parser.add_argument("--dataset-path", default='./datasets', type=str, help="Path of the trainset.")

     # random_split is used below but is not imported in the listings above.
     from torch.utils.data import DataLoader, random_split

     # Build the datasets
     train_dataset = MedicalDataset(args.dataset_path, 'train')
     test_dataset = MedicalDataset(args.dataset_path, 'test')
     print(len(train_dataset), len(test_dataset))       # 400 training samples, 200 test samples

     # Split the training set into training and validation subsets with an 8:2 ratio
     train_size = int(0.8 * len(train_dataset))
     val_size = len(train_dataset) - train_size
     # NOTE(review): both subsets wrap the same dataset object, so the validation
     # subset presumably gets the training-time random augmentation - confirm
     # whether that is intended.
     train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

     # Only the training loader shuffles; validation and test keep a fixed order.
     train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
     val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
     test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)

    3.oversample过采样

    假如train中WA和WKY的数据不平衡(例如:训练集中WA有1555张,WKY有496张;验证集中WA有223张,WKY有70张;测试集中WA有444张,WKY有142张),就需要对WKY的训练集和验证集进行过采样(不是单纯的重复,而是配合数据增强),测试集保持不变。

      1 import os
      2 import random
      3 import torch
      4 from torch.utils.data import Dataset
      5 import torchvision
      6 from PIL import Image
      7 
      class MedicalDataset(Dataset):
          """WA / WKY image dataset that oversamples the minority WKY class.

          Files are read from ``root/split/<class>/``; the label of a sample is
          the index of its class in ``CLASSES`` (WA -> 0, WKY -> 1).

          For the ``train`` and ``val`` splits every WKY file is listed three
          times, and each access applies one randomly chosen augmentation
          pipeline, so the repeated entries do not produce identical tensors.
          The ``test`` split is neither oversampled nor augmented, which keeps
          evaluation deterministic.

          Args:
              root: Directory that contains the split folders.
              split: One of ``'train'``, ``'val'`` or ``'test'``.
              data_ratio: Fraction of the files of each class to keep (chosen
                  at random, before oversampling). Values of 1.0 or more keep
                  everything.
              ret_name: If true, ``__getitem__`` also returns the file path.

          Raises:
              ValueError: If ``split`` is not a known split name, or if
                  ``data_ratio`` is negative.
          """

          CLASSES = ('WA', 'WKY')
          # class name -> how many times each of its files is listed (train/val only)
          OVERSAMPLE = {'WKY': 3}
          # Channel statistics commonly used with ImageNet-pretrained models.
          _MEAN = [0.485, 0.456, 0.406]
          _STD = [0.229, 0.224, 0.225]

          def __init__(self, root, split, data_ratio=1.0, ret_name=False):
              if split not in ('train', 'val', 'test'):
                  raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")
              super().__init__()
              self.ret_name = ret_name
              self.cls_to_ind_dict = {}    # class name -> label index
              self.ind_to_cls_dict = []    # label index -> class name
              self.img_list = []           # sample paths (oversampled entries repeat)
              self.cls_list = []           # label index per sample, parallel to img_list
              self.cls_num = {}            # class name -> number of entries after oversampling

              oversampling = split != 'test'
              for idx, cls in enumerate(self.CLASSES):
                  self.cls_to_ind_dict[cls] = idx
                  self.ind_to_cls_dict.append(cls)
                  cls_dir = os.path.join(root, split, cls)
                  paths = [os.path.join(cls_dir, img_fp)
                           for img_fp in sorted(os.listdir(cls_dir))]
                  if data_ratio < 1.0:
                      paths = self._subsample(paths, data_ratio)
                  if oversampling:
                      paths = self._repeat(paths, self.OVERSAMPLE.get(cls, 1))
                  self.cls_num[cls] = len(paths)
                  self.img_list.extend(paths)
                  self.cls_list.extend([idx] * len(paths))

              transforms = torchvision.transforms
              # Forced horizontal flip
              self.trans0 = self._pipeline(transforms.RandomHorizontalFlip(p=1))
              # Forced vertical flip
              self.trans1 = self._pipeline(transforms.RandomVerticalFlip(p=1))
              # Random rotation between -90 and 90 degrees
              self.trans2 = self._pipeline(transforms.RandomRotation(90))
              # Brightness factor drawn from [0, 2]; a factor of 1 leaves the image unchanged
              self.trans3 = self._pipeline(transforms.ColorJitter(brightness=1))
              # Contrast factor drawn from [0, 3]; a factor of 1 leaves the image unchanged
              self.trans4 = self._pipeline(transforms.ColorJitter(contrast=2))
              # Hue shift drawn from [-0.5, 0.5]
              self.trans5 = self._pipeline(transforms.ColorJitter(hue=0.5))
              # Brightness, contrast and hue jitter combined
              self.trans6 = self._pipeline(
                  transforms.ColorJitter(brightness=1, contrast=2, hue=0.5))
              # Deterministic pipeline used for evaluation on the test split
              self.eval_trans = transforms.Compose([
                  transforms.Resize(256),
                  transforms.CenterCrop(224),
                  transforms.ToTensor(),
                  transforms.Normalize(self._MEAN, self._STD),
              ])

              if oversampling:
                  self.trans_list = [self.trans0, self.trans1, self.trans2, self.trans3,
                                     self.trans4, self.trans5, self.trans6]
              else:
                  # Random flips / colour jitter on the test split would make the
                  # reported metrics change from run to run.
                  self.trans_list = [self.eval_trans]

          def _pipeline(self, augment):
              """Wrap one augmentation in the shared resize/crop/tensor/normalise steps."""
              transforms = torchvision.transforms
              return transforms.Compose([
                  transforms.Resize(256),
                  transforms.RandomCrop(224),
                  augment,
                  transforms.ToTensor(),
                  transforms.Normalize(self._MEAN, self._STD),
              ])

          @staticmethod
          def _subsample(paths, data_ratio):
              """Return a random ``data_ratio`` fraction of ``paths``, in sorted order."""
              keep = round(data_ratio * len(paths))
              return sorted(random.sample(paths, keep))

          @staticmethod
          def _repeat(items, times):
              """Return ``items`` with every element listed ``times`` times in a row."""
              return [item for item in items for _ in range(times)]

          def __getitem__(self, index):
              """Return ``(image, label)``, or ``(image, label, path)`` if ``ret_name`` is set."""
              name = self.img_list[index]
              # Normalize expects exactly three channels, so grayscale / RGBA
              # files are converted; the context manager closes the file handle.
              with Image.open(name) as img:
                  img = img.convert('RGB')
              # One pipeline is chosen uniformly at random on every access.
              img = random.choice(self.trans_list)(img)
              label = self.cls_list[index]
              if self.ret_name:
                  return img, label, name
              return img, label

          def __len__(self):
              """Return the number of entries, counting oversampled repeats."""
              return len(self.img_list)

     扩展后WKY的训练集个数为1488,验证集个数为210,测试集个数依然是142。通过过采样,无论WA做正例还是负例,得到的灵敏度都相似,不会有非常大的差别。

  • 相关阅读:
    排序算法 之 冒泡排序 插入排序 希尔排序 堆排序
    DataStructure之线性表以及其实现
    使用可重入函数进行更安全的信号处理
    内存经济学
    电脑通用技能
    循环套餐的逻辑
    占用了多少内存
    索引的用法
    电脑的眼缘
    字符串积木
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13949136.html
Copyright © 2011-2022 走看看