zoukankan      html  css  js  c++  java
  • torch_13_自定义数据集实战

    1.将图片的路径和标签写入csv文件并实现读取

     1  # 创建一个文件,包含image,存放方式:label pokemeon\mew\0001.jpg,0
     2     def load_csv(self,filename):
     3         if not os.path.exists(os.path.join(self.root,filename)):
     4             images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断
     5             for name in self.name2label.keys():
     6                 # pokemeon\mew\0001.jpg mew可以通过字典查看其类别
     7                 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径
     8                 images += glob.glob(os.path.join(self.root,name,'*.jpg'))
     9             random.shuffle(images)
    10             with open(os.path.join(self.root,filename),'w') as f:
    11                 writer = csv.writer(f)
    12                 for img in images:
    13                     name = img.split(os.sep)
    14                     label = self.name2label[name[-2]]
    15                     writer.writerow([img,label])
    16 
    17          # 从csv中读取文件
    18         images, labels = [], []
    19         with open(os.path.join(self.root,filename),'r') as f:
    20             reader = csv.reader(f)
    21             for row in reader:
    22                 img,label = row
    23                 label = int(label)
    24                 images.append(img)
    25                 labels.append(label)
    26         assert len(images) == len(labels) # 保证数据长度一致
           return images,labels

     2.加载自定义数据集

      1 """
      2 自定义数据集
      3 image_resize
      4 data argumentation(数据增强):Rotate,crop
      5 normalize:mean,std
      6 ToTensor
      7 
      8 """
      9 import torch
     10 import os,glob
     11 import random,csv
     12 from torch.utils.data import Dataset,DataLoader
     13 from torchvision import transforms
     14 from PIL import Image
     15 import visdom
     16 
     17 
     18 class Pokemon(Dataset):
     19     def __init__(self,root,resize,mode):
     20         super(Pokemon,self).__init__()
     21         self.root = root
     22         self.resize = resize
     23         self.name2label = {}
     24         for name in os.listdir(os.path.join(root)): #把文件和dir都会加载近来
     25             if not sorted(os.path.isdir(os.path.join(root,name))):#排序后,文件夹顺序固定了
     26                 continue
     27             self.name2label[name] = len(self.name2label.keys())
     28         # name2label:{文件夹名,类别编号}
     29         # 创建一个文件,包含image,存放方式:label pokemeon\mew\0001.jpg,0
     30         self.images, self.labels = self.load_csv('images.csv')
     31         # 对数据进行裁剪,mode:train-0.6,validation-0.2,test-0.2数据量是不同的
     32         if mode == 'train':
     33             self.images = self.images[:,int(len(self.images)*0.6)]
     34             self.labels = self.labels[:,int(len(self.images)*0.6)]
     35         elif mode == 'val':
     36             self.images = self.images[int(len(self.images)*0.6):int(len(self.images)*0.8)]
     37             self.labels = self.labels[int(len(self.labels)*0.6):int(len(self.labels)*0.8)]
     38         else:
     39             self.images = self.images[int(len(self.images) * 0.8):]
     40             self.labels = self.labels[int(len(self.labels) * 0.8):]
     41 
     42     def load_csv(self,filename):
     43         if not os.path.exists(os.path.join(self.root,filename)):
     44             images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断
     45             for name in self.name2label.keys():
     46                 # pokemeon\mew\0001.jpg mew可以通过字典查看其类别
     47                 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径
     48                 images += glob.glob(os.path.join(self.root,name,'*.jpg'))
     49             random.shuffle(images)
     50             with open(os.path.join(self.root,filename),'w') as f:
     51                 writer = csv.writer(f)
     52                 for img in images:
     53                     name = img.split(os.sep)
     54                     label = self.name2label[name[-2]]
     55                     writer.writerow([img,label])
     56          # 从csv中读取文件
     57         images, labels = [], []
     58         with open(os.path.join(self.root,filename),'r') as f:
     59             reader = csv.reader(f)
     60             for row in reader:
     61                 img,label = row
     62                 label = int(label)
     63                 images.append(img)
     64                 labels.append(label)
     65         assert len(images) == len(labels) # 保证数据长度一致
     66         return images,labels
     67 
     68     def __len__(self):
     69         return len(self.images)
     70 
     71     def __getitem__(self, idx):
     72         # idx是[0-len(self.images]
     73         # self.images,self.label
     74         # img:pokemeon\mew\0001.jpg(这是一个路径)要转变成img数据
     75         # label:是数字
     76         img, label = self.images[idx], self.labels[idx]
     77         tf = transforms.Compose([
     78             lambda x:Image.open(x).convert('RGB'),# string path -> img data
     79             transforms.Resize(int(self.resize*1.25), int(self.resize*1.25)),
     80             transforms.Randomrotation(15), # 旋转度数
     81             transforms.CenterCrop(self.resize),#中心裁剪,保留resize大小
     82             transforms.ToTensor(),
     83             transforms.Normalize(mean=[0.485,0.456,0.406],
     84                                  std=[0.229,0.224,0.225])  # 归一化之后,范围为-1~1,之前的图片范围为0~1
     85             ])
     86         img = tf(img)  # 将path转换成数据
     87         label = torch.tensor(label)  # 将变量label转换成tensor
     88         return img,label
     89 
     90     def denormalize(self,x_hat):
     91         mean=[0.485,0.456,0.406]
     92         std=[0.229,0.224,0.225]
     93         # x:[c,h,w]
     94         # x_hat = (x-mean)/std
     95         # maen[3]->[3,1,1]
     96         mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
     97         std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
     98         x = x_hat * std+mean
     99         return x
    100 
    101 def main():
    102     import torchvision
    103     vis = visdom.Visdom()
    104     """
    105     如果存储比较规范的话,可以使用下面简单的代码加载数据集,文件夹的标签从0开始编码
    106     tf = transforms.Compose([
    107         transforms.Resize((64,64)),
    108         transforms.ToTensor()
    109     ])
    110     db = torchvision.datasets.ImageFolder('./pokemon',transform=tf)
    111     loader = DataLoader(db,batch_size=32,shuffle=True)
    112     print(db.class_to_idx) #查看类标签
    113     
    114     """
    115     db = Pokemon('./pokemon', 224, 'train') # 根据idx,返回一个
    116     x,y = next(iter(db))
    117     print('sample:',x.shape,y.shape)
    118     #可视化
    119     vis.image(db.denormalize(x),win='sample_x',opts=dict(title = 'sample_x'))
    120     # 加载一批
    121     loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 )
    122     for x,y in loader:
    123         vis.images(db.denormalize(x), nrow=8, win='batch',opts=dict(title='batch'))
    124         vis.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
    125 
    126 
    127 if __name__ == '__main__':
    128     main()

     小结:

    在加载自定义数据集时,一般步骤

    1.定义一个类继承Dataset

    2.在类中读取数据集(图片的路径),重写len函数,和getitem函数

    在len函数中返回数据集的长度

    在getitem函数中,处理一张图片,单个图片路径转换成图片数据(包括transform转换),返回该图片数据和标签

    3.将处理好的数据集(getitem返回的均为张量)放入DataLoader中,进行分批

    loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 )

    4.训练时通过enumerate(loader)遍历每个batch

  • 相关阅读:
    linux安装篇之mongodb安装及服务自启动配置
    Linux下启动mongodb
    java 实现 图片与byte 数组互相转换
    用java imageio调整图片DPI,例如从96调整为300
    StringRedisTemplate操作redis数据
    Docker 更换国内的Hub源
    2、Docker 基础安装和基础使用 一
    Centos 6.x Openssh 升级 7.7p1 版本
    1、Docker 简介
    2. Python环境安装
  • 原文地址:https://www.cnblogs.com/shuangcao/p/11905505.html
Copyright © 2011-2022 走看看