  • PyTorch data loading

    How PyTorch reads data:

    A sampler generates indices, and each index is used to fetch an image and its label from the Dataset.
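
    The index-driven flow described above can be sketched in plain Python. This is only an illustration of the idea, not PyTorch's actual implementation; `ToyDataset`, `make_indices`, and `iterate_batches` are hypothetical names introduced for this sketch.

```python
import random


class ToyDataset:
    """Hypothetical stand-in for a Dataset: maps an index to (image, label)."""

    def __init__(self, n):
        self.samples = [("img_{}".format(i), i % 2) for i in range(n)]

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


def make_indices(n, shuffle, seed=0):
    """Plays the role of the sampler: decides the order in which indices are visited."""
    indices = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(indices)
    return indices


def iterate_batches(dataset, batch_size, shuffle=False):
    """Group the sampler's indices into batches and fetch each sample via dataset[index]."""
    indices = make_indices(len(dataset), shuffle)
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        yield [dataset[i] for i in batch_indices]


batches = list(iterate_batches(ToyDataset(5), batch_size=2))
print(batches[0])  # [('img_0', 0), ('img_1', 1)]
```

    The dataset only needs to answer "what is sample i?"; the order of the indices and their grouping into batches are decided outside of it.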

    1.torch.utils.data.DataLoader

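    Purpose: builds an iterable data loader.

    dataset: a Dataset instance; determines where the data is read from and how it is read

    batch_size: number of samples per batch

    num_workers: number of subprocesses used to load data; 0 loads in the main process. When resources allow, loading with multiple processes speeds up data reading.

    shuffle: whether to reshuffle the samples at every epoch

    drop_last: whether to discard the last batch when the number of samples is not divisible by batch_size

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)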
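
    A minimal usage sketch of these parameters, assuming PyTorch is installed. The data is made up (ten one-feature samples wrapped in a `TensorDataset`) and is not part of the original example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten toy samples: one feature each, labels 0..9 (made-up data for illustration).
features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# num_workers=0 loads in the main process; shuffle=False keeps the order fixed.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, drop_last=False)

batch_sizes = [len(y) for x, y in loader]
print(batch_sizes)  # [4, 4, 2]
```

    With drop_last=True the final batch of 2 samples would be skipped and the loader would yield only two batches.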

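    epoch: one pass in which every training sample has been fed to the model

    iteration: one batch of samples fed to the model

    batch_size: number of samples per batch; determines how many iterations make up one epoch

    For example:

    Total samples: 80, batch_size: 8

    1 epoch = 10 iterations

    Total samples: 87, batch_size: 8

    1 epoch = 10 iterations with drop_last=True

    1 epoch = 11 iterations with drop_last=False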
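
    The arithmetic behind these numbers is a floor or a ceiling division. A small sketch follows; `iterations_per_epoch` is a hypothetical helper, not a PyTorch function.

```python
import math


def iterations_per_epoch(num_samples, batch_size, drop_last):
    """Number of batches yielded in one epoch."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept.
    return math.ceil(num_samples / batch_size)


print(iterations_per_epoch(80, 8, drop_last=False))  # 10
print(iterations_per_epoch(87, 8, drop_last=True))   # 10
print(iterations_per_epoch(87, 8, drop_last=False))  # 11
```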

    2.torch.utils.data.Dataset

    Purpose: Dataset is an abstract class. Every custom Dataset must inherit from it and override __getitem__().

    __getitem__: takes an index and returns one sample

    class Dataset(object):
        def __getitem__(self, index):
            raise NotImplementedError
        
        def __add__(self, other):
            return ConcatDataset([self, other])
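
    As a sketch of what subclassing looks like in practice, here is a toy dataset, assuming PyTorch is installed. `SquaresDataset` is a hypothetical class used only for illustration. It also defines `__len__`, which the default samplers and `ConcatDataset` rely on.

```python
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index * index

    def __len__(self):
        return self.n


ds = SquaresDataset(5)
print(ds[3])  # (3, 9)

# The inherited __add__ wraps both datasets in a ConcatDataset.
combined = ds + SquaresDataset(3)
print(len(combined))  # 8
print(combined[6])    # (1, 1): index 6 falls on sample 1 of the second dataset
```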

    RMB (renminbi) banknote classification example:

    Splitting the data:

    import os
    import random
    import shutil
    
    
    def makedir(new_dir):
        if not os.path.exists(new_dir):
            os.makedirs(new_dir)
    
    
    if __name__ == '__main__':
    
        random.seed(1)
    
        dataset_dir = os.path.join("..", "..", "data", "RMB_data")
        split_dir = os.path.join("..", "..", "data", "rmb_split")
        train_dir = os.path.join(split_dir, "train")
        valid_dir = os.path.join(split_dir, "valid")
        test_dir = os.path.join(split_dir, "test")
    
        train_pct = 0.8
        valid_pct = 0.1
        test_pct = 0.1
    
        # Each sub-directory of dataset_dir holds the images of one class.
        for root, dirs, files in os.walk(dataset_dir):
            for sub_dir in dirs:

                imgs = os.listdir(os.path.join(root, sub_dir))
                imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
                random.shuffle(imgs)
                img_count = len(imgs)

                # Indices [0, train_point) go to train, [train_point, valid_point) to valid, the rest to test.
                train_point = int(img_count * train_pct)
                valid_point = int(img_count * (train_pct + valid_pct))

                for i in range(img_count):
                    if i < train_point:
                        out_dir = os.path.join(train_dir, sub_dir)
                    elif i < valid_point:
                        out_dir = os.path.join(valid_dir, sub_dir)
                    else:
                        out_dir = os.path.join(test_dir, sub_dir)

                    makedir(out_dir)

                    target_path = os.path.join(out_dir, imgs[i])
                    # Build the source path from root so it matches the directory that was listed above.
                    src_path = os.path.join(root, sub_dir, imgs[i])

                    shutil.copy(src_path, target_path)

                print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
                                                                     img_count-valid_point))

    Creating the Dataset:

    import os
    import random
    from PIL import Image
    from torch.utils.data import Dataset
    
    random.seed(1)
    rmb_label = {"1": 0, "100": 1}
    
    
    class RMBDataset(Dataset):
        def __init__(self, data_dir, transform=None):
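            """
            Dataset for the RMB denomination classification task
            :param data_dir: str, path to the dataset
            :param transform: torchvision.transforms, data preprocessing
            """
            self.label_name = {"1": 0, "100": 1}
            self.data_info = self.get_img_info(data_dir)  # data_info stores every image path and label; the DataLoader reads samples from it by index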
            self.transform = transform
    
        def __getitem__(self, index):
            path_img, label = self.data_info[index]
            img = Image.open(path_img).convert('RGB')     # 0~255
    
            if self.transform is not None:
                img = self.transform(img)   # apply the transform here, e.g. conversion to a tensor
    
            return img, label
    
        def __len__(self):
            return len(self.data_info)
    
        @staticmethod
        def get_img_info(data_dir):
            data_info = list()
            for root, dirs, _ in os.walk(data_dir):
                # iterate over the classes
                for sub_dir in dirs:
                    img_names = os.listdir(os.path.join(root, sub_dir))
                    img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
    
                    # iterate over the images
                    for i in range(len(img_names)):
                        img_name = img_names[i]
                        path_img = os.path.join(root, sub_dir, img_name)
                        label = rmb_label[sub_dir]
                        data_info.append((path_img, int(label)))
    
            return data_info

    3.transforms

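    torchvision.transforms: common image processing methods

    Centering, standardization, resizing, cropping, rotation, flipping, padding, noise injection, grayscale conversion, linear transformation, affine transformation, and adjustment of brightness, saturation, and contrast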

  • Original article: https://www.cnblogs.com/haiboxiaobai/p/11749379.html