zoukankan      html  css  js  c++  java
  • pytorch数据读取

    pytorch数据读取机制:

    sampler生成索引index,根据索引从DataSet中获取图片和标签

    1.torch.utils.data.DataLoader

    功能:构建可迭代的数据装在器

    dataset:Dataset类,决定数据从哪读取及如何读取

    batchsize:批大小

    num_works:是否多进程读取数据,当条件允许时,多进程读取数据会加快数据读取速度。

    shuffle:每个epoch是否乱序

    drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

    DataLoader(dataset, batchsize=1, shuffle=False, batch_sampler=None, num_workers=0, collate_fn=None, pin_memeory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)

    epoch:所有训练样本都已输入到模型中,称为一个epoch

    iteration:一批样本输入到模型中,称为一个iteration

    batchsize:批大小,决定一个epoch有多少个iteration

    例如:

    样本总数:80, batchsize:8

    1epoch = 10 iteraion

    样本总数:87, batchsize:8

    1 epoch = 10 iteration drop_last=True

    1 epoch = 11 iteration drop_last=False

    2.torch.utils.data.Dataset

    功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写

    __getitem__()

    getitem:接收一个索引,返回一个样本

    class Dataset(object):
        def __getitem__(self, index):
            raise NotImplementedError
        
        def __add__(self, other):
            return ConcatDataset([self, other])

    人命币分类实例:

    数据分割:

    import os
    import random
    import shutil
    
    
    def makedir(new_dir):
        if not os.path.exists(new_dir):
            os.makedirs(new_dir)
    
    
    if __name__ == '__main__':
    
        random.seed(1)
    
        dataset_dir = os.path.join("..", "..", "data", "RMB_data")
        split_dir = os.path.join("..", "..", "data", "rmb_split")
        train_dir = os.path.join(split_dir, "train")
        valid_dir = os.path.join(split_dir, "valid")
        test_dir = os.path.join(split_dir, "test")
    
        train_pct = 0.8
        valid_pct = 0.1
        test_pct = 0.1
    
        for root, dirs, files in os.walk(dataset_dir):
            for sub_dir in dirs:
    
                imgs = os.listdir(os.path.join(root, sub_dir))
                imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
                random.shuffle(imgs)
                img_count = len(imgs)
    
                train_point = int(img_count * train_pct)
                valid_point = int(img_count * (train_pct + valid_pct))
    
                for i in range(img_count):
                    if i < train_point:
                        out_dir = os.path.join(train_dir, sub_dir)
                    elif i < valid_point:
                        out_dir = os.path.join(valid_dir, sub_dir)
                    else:
                        out_dir = os.path.join(test_dir, sub_dir)
    
                    makedir(out_dir)
    
                    target_path = os.path.join(out_dir, imgs[i])
                    src_path = os.path.join(dataset_dir, sub_dir, imgs[i])
    
                    shutil.copy(src_path, target_path)
    
                print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
                                                                     img_count-valid_point))

    创建Dataset

    import os
    import random
    from PIL import Image
    from torch.utils.data import Dataset
    
    random.seed(1)
    rmb_label = {"1": 0, "100": 1}
    
    
    class RMBDataset(Dataset):
        def __init__(self, data_dir, transform=None):
            """
            rmb面额分类任务的Dataset
            :param data_dir: str, 数据集所在路径
            :param transform: torch.transform,数据预处理
            """
            self.label_name = {"1": 0, "100": 1}
            self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
            self.transform = transform
    
        def __getitem__(self, index):
            path_img, label = self.data_info[index]
            img = Image.open(path_img).convert('RGB')     # 0~255
    
            if self.transform is not None:
                img = self.transform(img)   # 在这里做transform,转为tensor等等
    
            return img, label
    
        def __len__(self):
            return len(self.data_info)
    
        @staticmethod
        def get_img_info(data_dir):
            data_info = list()
            for root, dirs, _ in os.walk(data_dir):
                # 遍历类别
                for sub_dir in dirs:
                    img_names = os.listdir(os.path.join(root, sub_dir))
                    img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
    
                    # 遍历图片
                    for i in range(len(img_names)):
                        img_name = img_names[i]
                        path_img = os.path.join(root, sub_dir, img_name)
                        label = rmb_label[sub_dir]
                        data_info.append((path_img, int(label)))
    
            return data_info

    3.transforms

    torch.transforms:常用图像处理方法

    数据中心化  数据标准化  缩放  裁剪  旋转  翻转  填充  噪声添加  灰度转换  线性变换  仿射变换  亮度、饱和度及对比度

  • 相关阅读:
    WCF 第八章 安全 确定替代身份(中)使用AzMan认证
    WCF 第八章 安全 总结
    WCF 第八章 安全 因特网上的安全服务(下) 其他认证模式
    WCF Membership Provider
    WCF 第八章 安全 确定替代身份(下)模仿用户
    WCF 第八章 安全 因特网上的安全服务(上)
    WCF 第九章 诊断
    HTTPS的七个误解(转载)
    WCF 第八章 安全 日志和审计
    基于比较的排序算法集
  • 原文地址:https://www.cnblogs.com/haiboxiaobai/p/11749379.html
Copyright © 2011-2022 走看看