zoukankan      html  css  js  c++  java
  • little_by_little_2 为一个数据集创建一个dataset类。(基于pytorch)

    little_by_little_2 为一个数据集创建一个dataset类。(基于pytorch)

    前言

    最近一段时间陷入了焦虑,迷茫之中最终获得了救赎。不想提及。

    任务

    为一个分类100元和1元的数据集创建一个pytorch.dataset,以便dataloader来读取

    源代码

    import os
    import random
    from PIL import Image
    from torch.utils.data import Dataset
    
    random.seed(1)
    rmb_label = {"1": 0, "100": 1}
    
    #1
    class RMBDataset(Dataset):
        def __init__(self, data_dir, transform=None):
            """
            rmb面额分类任务的Dataset
            :param data_dir: str, 数据集所在路径
            :param transform: torch.transform,数据预处理
            """
            self.label_name = {"1": 0, "100": 1}
            self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
            self.transform = transform
    #2
        def __getitem__(self, index):
            path_img, label = self.data_info[index]
            img = Image.open(path_img).convert('RGB')     # 0~255
    
            if self.transform is not None:
                img = self.transform(img)   # 在这里做transform,转为tensor等等
    
            return img, label
    
        def __len__(self):
            return len(self.data_info)
    #3
        @staticmethod
        def get_img_info(data_dir):
            data_info = list()
            for root, dirs, _ in os.walk(data_dir):
                # 遍历类别
                for sub_dir in dirs:
                    img_names = os.listdir(os.path.join(root, sub_dir))
                    img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
    
                    # 遍历图片
                    for i in range(len(img_names)):
                        img_name = img_names[i]
                        path_img = os.path.join(root, sub_dir, img_name)
                        label = rmb_label[sub_dir]
                        data_info.append((path_img, int(label)))
    
            return data_info
    
    

    解读

    #1部分

    class RMBDataset(Dataset):
        def __init__(self, data_dir, transform=None):
            """
            rmb面额分类任务的Dataset
            :param data_dir: str, 数据集所在路径
            :param transform: torch.transform,数据预处理
            """
            self.label_name = {"1": 0, "100": 1}
            self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
            self.transform = transform
    

    初始化数据不多作赘述。

    #2部分

        def __getitem__(self, index):
            path_img, label = self.data_info[index]
            img = Image.open(path_img).convert('RGB')     # 0~255
    
            if self.transform is not None:
                img = self.transform(img)   # 在这里做transform,转为tensor等等
    
            return img, label
    
        def __len__(self):
            return len(self.data_info)
    
    • 为什么要在_get_ item 里面定义?因为pytorch中用dataloader类调用dataset类的时候是这样子的:

    • path_img, label = self.data_info[index] 接收数据的数据以及标签

    • img = Image.open(path_img).convert('RGB') # 0~255 将img转换成三通道模式

    •     if self.transform is not None:
              img = self.transform(img)   # 在这里做transform,转为tensor等等
      

    判断是否传入了transform,若传入了transform则进行transform.compounds里面的transform变换.

    • return img, label 返回数据及标签

    # 3 部分

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
    
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))
    
        return data_info
    
    • 此函数的作用,得到路径内所有图片的数据,并打上label
    • for root, dirs, _ in os.walk(data_dir): 此处涉及到os.walk函数,
    def walk(top: T,
     topdown: bool = True,
     onerror: Optional[(Exception) -> None] = None,
     followlinks: bool = False) -> Iterator[Tuple[T, List[T], List[T]]]
     top -- 是你所要遍历的目录的地址, 
     return--返回的是一个三元组(root,dirs,files)。
    
        root 所指的是当前正在遍历的这个文件夹的本身的地址
        dirs 是一个 list ,内容是该文件夹中所有的目录的名字(不包括子目录)
        files 同样是 list , 内容是该文件夹中所有的文件(不包括子目录)
                                       
    topdown --可选,为 True,则优先遍历 top 目录,否则优先遍历 top 的子目录(默认为开启)。如果 topdown 参数为 True,walk 会遍历top文件夹,与top 文件夹中每一个子目录。
    
    onerror -- 可选,需要一个 callable 对象,当 walk 需要异常时,会调用。
    
    followlinks -- 可选,如果为 True,则会遍历目录下的快捷方式(linux 下是软连接 symbolic link )实际所指的目录(默认关闭),如果为 False,则优先遍历 top 的子目录。
    
    • for sub_dir in dirs:
                  img_names = os.listdir(os.path.join(root, sub_dir))
                  img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
      

      先解释一下目录结构:

    image-20200416133247390

    1和100里面放着1和100元的图片.

    img_names = os.listdir(os.path.join(root, sub_dir)) 提取出.../1

    img_names = list(filter(lambda x: x.endswith('.jpg'), img_names)) 将.../1 下面所有的以.jpg结尾的文件名提取出来返回一个list,也就是说此时img_names成了一个list里面装满了.../1目录下所有的图片名字

    •         for i in range(len(img_names)):
                  img_name = img_names[i]
                  path_img = os.path.join(root, sub_dir, img_name)
                  label = rmb_label[sub_dir]
                  data_info.append((path_img, int(label)))
      

      这个函数主要作用是提取出img_names里面所有图片的路径以及label其中值得一提的是label = rmb_label[sub_dir] 由于本身文件夹的名字就是label所以提取label的方法就是提取文件夹的名字.

      最后返回一个data_info list 里面每个元素为元组形式(img_path,label).

  • 相关阅读:
    修复 Visual Studio Error “No exports were found that match the constraint”
    RabbitMQ Config
    Entity Framework Extended Library
    Navisworks API 简单二次开发 (自定义工具条)
    NavisWorks Api 简单使用与Gantt
    SQL SERVER 竖表变成横表
    SQL SERVER 多数据导入
    Devexpress GridControl.Export
    mongo DB for C#
    Devexress XPO xpPageSelector 使用
  • 原文地址:https://www.cnblogs.com/negu/p/12712337.html
Copyright © 2011-2022 走看看