zoukankan      html  css  js  c++  java
  • 龙良曲pytorch学习笔记_加载宝可梦数据集

      1 import torch
      2 import os,glob
      3 import random,csv
      4 
      5 from torch.utils.data import Dataset,DataLoader
      6 
      7 from torchvision import transforms
      8 from PIL import Image
      9 
     10 class Pokemon(Dataset):
     11     '''
     12         @param
     13         root:存储的根路径
     14         resize:将图片大小根据网络结构适配
     15         mode:train或者test模式
     16     '''
     17     def __init__(self,root,resize,mode):
     18         super(Pokemon,self).__init__()
     19         
     20         self.root = root
     21         self.resize = resize
     22         
     23         # 字典类型key:name value:label
     24         self.name2label = {}
     25         # listdir返回顺序不固定,用sorted将它固定,因为排序一次之后就固定了
     26         for name in sorted(os.listdir(os.path.join(root))):
     27             if not os.path.isdir(os.path.join(root,name)):
     28                 continue
     29                 
     30             self.name2label[name] = len(self.name2label.keys())
     31         
     32         # print(self.name2label)
     33         
     34         # image_path + image_label
     35         self.images,self.labels = self.load_csv('images.csv')
     36         
     37         if mode == 'train': # 60%
     38             self.images = self.images[:int(0.6*len(self.images))]
     39             self.labels = self.labels[:int(0.6*len(self.labels))]
     40         elif mode == 'val': # 20%
     41             self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
     42             self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
     43         elif mode == 'test': # 20% = 80% ->100%
     44             self.images = self.images[int(0.8*len(self.images)):]
     45             self.labels = self.labels[int(0.8*len(self.labels)):]
     46             
     47     def load_csv(self,filename):
     48         
     49         # 如果不存在再写入,存在的话直接读取就可以了
     50         if not os.path.exists(os.path.join(self.root,filename))
     51             images = []
     52             for name in self.name2label.keys():
     53                 # 'pokemon'\mewtwo\00001.png
     54                 images += glob.glob(os.path.join(self.root,name,'*.png'))
     55                 images += glob.glob(os.path.join(self.root,name,'*.jpg'))
     56                 images += glob.glob(os.path.join(self.root,name,'*.jpeg'))
     57                 
     58             # 1167,'pokemon\bulbasaur\00000000.png'
     59             print(len(images),images)
     60             
     61             random.shuffle(images)
     62             with open(os.path.join(self.root,filename),mode = 'w',newline='') as f:
     63                 writer = csv.writer(f)
     64                 for img in images: # 'pokemon\bulbasaur\00000000.png'
     65                     name = img.split(os.sep)[-2]
     66                     label = self.name2label[name]
     67                     # 'pokemon\bulbasaur\00000000.png',0
     68                     writer.writerow([img,label])
     69                 print('writen into csv file:',filename)
     70             
     71         # read from csv file
     72         images,labels = [],[]
     73         with open(os.path.join(self.root,filename))
     74             reader = csv.reader(f)
     75             for row in reader:
     76                 # 'pokemon\bulbasaur\00000000.png',0
     77                 img,label = row
     78                 label = int(label)
     79                 
     80                 images.append(img)
     81                 labels.append(label)
     82                 
     83         # 保证images和labels一一对应,长度相等
     84         assert len(images) == len(labels)
     85         return images,labels
     86             
     87     def __len__(self):
     88         
     89         return len(self.images)
     90         
     91     def denormalize(self,x_hat):
     92     
     93         mean=[0.485,0.456,0.406]
     94         std=[0.229,0.224,0.225]
     95         
     96         # x_hat = (x-mean)/std
     97         # x = x_hat*std+mean
     98         # x: [c,h,w]
     99         # mean: [3] --> [3,1,1]
    100         mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
    101         std  = torch.tensor(std).unsqueeze(1).unsqueeze(1)
    102         
    103         x = x_hat*std + mean
    104         
    105         return x
    106         
    107     
    108     def __getitem__(self,idx):
    109         # idx~[0~len(images)]
    110         # self.images,self.labels
    111         # img: pokemon\bulbasaur\00000000.png'
    112         # label: 0
    113         img,label = self.images[idx],self.labels[idx]
    114         
    115         tf = transforms.Compose([
    116             lambda x:Image.open(x).convert('RGB'), # string path --> image data
    117             transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
    118             transforms.RandomRotation(15),
    119             transforms.CenterCrop(self.resize),
    120             transforms.ToTensor(),
    121             transforms.Normalize(mean=[0.485,0.456,0.406],
    122                                  std=[0.229,0.224,0.225])
    123         ])
    124         
    125         img = tf(img)
    126         label = torch.tensor(label)
    127         
    128         return img,label
  • 相关阅读:
    再谈ORACLE CPROCD进程
    fopen()函数
    Java抓取网页数据(原网页+Javascript返回数据)
    Vmware ESX 5.0 安装与部署
    PostgreSQL服务端监听设置及client连接方法
    方向梯度直方图(HOG)和颜色直方图的一些比較
    Vim简明教程【CoolShell】
    FileSystemWatcher使用方法具体解释
    几种常见模式识别算法整理和总结
    ThreadPool.QueueUserWorkItem的性能问题
  • 原文地址:https://www.cnblogs.com/fxw-learning/p/12331522.html
Copyright © 2011-2022 走看看