zoukankan      html  css  js  c++  java
  • 0701-数据处理

    0701-数据处理

    pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html

    一、概述

    在机器学习中,尤其是在深度学习中,需要耗费大量的精力去处理数据,并且数据的处理对训练神经网络来说也是很重要的,良好的数据不仅会加速模型的训练,也可以提高模型的效率。

    为此,torch 提供了几个高效便捷的工具,以便使用者更方便的对数据做处理,同时也可以并行化加速数据加载。

    二、加载自定义数据集

    在 torch 中,可以加载自定义数据集,在这个过程中,需要自定义数据集对象,数据集对象将被抽象为 Dataset 类,也就是说实现自定义的数据集需要继承 Dataset,同时也需要实现两个 Python 魔法方法:

    • __getiter__:返回一条数据或一个样本。obj[index] 等价于 obj.__getitem__(index)
    • __len__:返回样本的数量。len(obj) 等价于 obj.__len__()

    在这里我们以 Kaggle 经典挑战赛“Dogs vs. Cat”的数据为例,详细讲解如何处理数据。其中该数据是一个分类问题的数据,判断一张图片是狗还是猫,它的所有图片都放在一个文件夹下,并可以根据文件名的前缀是狗还是猫。需要图片数据的可以加我微信:chenyoudea

    import os
    
    imgs = os.listdir('./img/dogcat')  # 获取./img/dogcat下的所有图片文件
    for img in imgs:
        print(img)
    
    dog.12497.jpg
    cat.12484.jpg
    cat.12485.jpg
    dog.12496.jpg
    cat.12487.jpg
    cat.12486.jpg
    dog.12498.jpg
    dog.12499.jpg
    
    import os
    import torch as t
    import numpy as np
    from PIL import Image
    from torch.utils import data
    
    
    class DogCat(data.Dataset):
        def __init__(self, root):
            imgs = os.listdir(root)
            # 所有图片的绝对路径
            # 这里不实际加载图片,只是指定路径
            # 当调用__getitem__时才会真正读图片
            self.imgs = [os.path.join(root, img) for img in imgs]
    
        def __getitem__(self, index):
            img_path = self.imgs[index]
            # dog->1, cat->0
            label = 1 if 'dog' in img_path.split(
                '/')[-1] else 0  # 通过对图片文件名前缀的判断给图片增加标签
            pil_img = Image.open(img_path)  # 打开图片
            array = np.asarray(pil_img)  # 把图片转为 ndarray 数据
            data = t.from_numpy(array)  # 把图片转为 Tensor 数据
            return data, label
    
    
    dataset = DogCat('./img/dogcat/')
    # img, label = dataset[0]  # 相当于调用 dataset.__getitem__(0)
    for img, label in dataset:
        print(img.size(), img.float().mean(), label)
    
    torch.Size([375, 499, 3]) tensor(150.5080) 1
    torch.Size([500, 497, 3]) tensor(106.4915) 0
    torch.Size([499, 379, 3]) tensor(171.8085) 0
    torch.Size([375, 499, 3]) tensor(116.8139) 1
    torch.Size([374, 499, 3]) tensor(115.5177) 0
    torch.Size([236, 289, 3]) tensor(130.3004) 0
    torch.Size([377, 499, 3]) tensor(151.7174) 1
    torch.Size([400, 300, 3]) tensor(128.1550) 1
    
    
    /Applications/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:23: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:143.)
    

    上述所示的 /Applications/anaconda3/lib…… 的错误,是因为图片是 git 上拿下来的,没有修改权限,我懒得修改了,自己有空把它修改下,反正没啥影响。

    对于我们自定义的数据集,我们已经学会了如何通过代码定义这样的数据集,但是这样的数据并不适合使用,因为它们有两个这样的问题:

    1. 每张图片的大小不一样,这对于需要取 batch 训练的神经网络来说并不友好
    2. 返回样本的数值较大,没有归一化到 [-1,1]

    三、利用 torchvision 工具处理数据集

    为了解决上一节的遗留的问题,torch 提供了 torchvision,它是一个视觉工具包,提供了很多视觉图像处理的工具,其中 transform 模块提供了对 PIL Image 对象和 Tensor 对象的常用操作。如果想更详细的了解这个工具,可以去去查看官方文档:https://github.com/pytorch/vision/

    对 PIL Image 的常见操作如下:

    • Resize:调整图片尺寸
    • CenterCrop、RandomCrop、RandomSizedCrop:裁剪图片
    • Pad:填充
    • ToTensor:把 PIL Image 对象转成 Tensor,会自动将 [0,255] 归一化为 [0,1]

    对 Tensor 的常见操作如下:

    • Normalize:标准化,即减均值,除以标准差
    • ToPILImage:将 Tensor 转为 PIL Image 对象

    如果需要对图片进行多个操作,可以通过 Compose 把这些操作拼接起来,类似于 nn.Sequential。需要注意的是,这些操作定义后是以对象的形式存在,真正使用时需要调用它的 __call__ 方法,类似于 nn.Module

    例如,如果要把图片调整为 224*224,首先构建操作 trans = Scale((224,224)),然后调用 trans(img)。接下来我们就用 transform 的这些操作来优化上面实现的 dataset。

    import os
    from PIL import Image
    import numpy as np
    from torchvision import transforms as T
    
    transform = T.Compose([
        T.Resize(224),  # 缩放图片,保持长宽比不变,最短边为 224 像素
        T.CenterCrop(224),  # 从图片中间切出 224*224 的图片
        T.ToTensor(),  # 把图片转成 Tensor,归一化至 [0,1]
        T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])  # 标准化至 [-1,1]
    ])
    
    
    class DogCat(data.Dataset):
        def __init__(self, root, transforms=None):
            imgs = os.listdir(root)
            self.imgs = [os.path.join(root, img) for img in imgs]  # 拼接图片路径
            self.transforms = transforms  # 作为图片是否进行处理的标志
    
        def __getitem__(self, index):
            img_path = self.imgs[index]
            label = 0 if 'dog' in img_path.split('/')[-1] else 1
            data = Image.open(img_path)
            if self.transforms:  # 判断图片是否需要进行处理
                data = self.transforms(data)
            return data, label
    
        def __len__(self):
            return len(self.imgs)
    
    
    dataset = DogCat('./img/dogcat/', transforms=transform)
    img, label = dataset[0]
    for img, label in dataset:
        print(img.size(), label)
    
    torch.Size([3, 224, 224]) 0
    torch.Size([3, 224, 224]) 1
    torch.Size([3, 224, 224]) 1
    torch.Size([3, 224, 224]) 0
    torch.Size([3, 224, 224]) 1
    torch.Size([3, 224, 224]) 1
    torch.Size([3, 224, 224]) 0
    torch.Size([3, 224, 224]) 0
    

    从上述代码可以看到 transforms 的强大,除了这些,transforms 还可以通过 Lambda 封装自定义的转换策略。

    例如,如果相对 PIL Image 进行随机旋转,则可以写成 trans = T.Lambda(lambda img: img.rotate(random()*360))

    上面我们说到了如何加载自定义的数据集,对于很多研究者来说,只是想试验自己的算法有没有问题,如果自己去获取数据,再加上深度学习对数据量的要求,那是非常困难的。

    为此 torchvision 预先实现了常用的 Dataset,包括 CIFAR-10、ImageNet、COCO、MNIST、LSUN 等数据集,可以通过调用 torchvision.datasets 下相应的对象来调用相关的数据集,具体的使用方法可以查看官方文档:https://pytorch.org/vision/stable/datasets.html

    四、ImageFolder 的使用——处理数据集

    本节介绍一个我们经常会用到的一个 Dataset——ImageFolder,它的实现和上述 DogCat类 的功能类似,主要是对图片进行处理。

    ImageFoder 假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,它的构造函数如下所示:ImageFolder(root, transform=None, target_transform=None, loader=default_loader)

    它主要有以下四个参数:

    • root:在 root 指定的路径下寻找图片
    • transform:对 PIL Image进行转换操作,transform 的输入是使用 loader 读取图片的返回对象
    • target_transform:对 label 的转换
    • loader:指定加载图片的函数,默认操作是读取为 PIL Image 对象

    label 是按照文件夹名字顺序排序后存成字典的,即 {类名:类序号(从 0 开始)},一般来说最好直接将文件命名为从 0 开始的数字,这样回合 ImageFolder 实际的 label 一致。

    from torchvision.datasets import ImageFolder
    
    dataset = ImageFolder('./img/dogcat_2')
    
    # cat 文件夹的图片对应 label 0,dog 对应 1
    dataset.class_to_idx
    
    {'cat': 0, 'dog': 1}
    
    # 所有图片的路径和对应的 label
    dataset.imgs
    
    [('./img/dogcat_2/cat/cat.12484.jpg', 0),
     ('./img/dogcat_2/cat/cat.12485.jpg', 0),
     ('./img/dogcat_2/cat/cat.12486.jpg', 0),
     ('./img/dogcat_2/cat/cat.12487.jpg', 0),
     ('./img/dogcat_2/dog/dog.12496.jpg', 1),
     ('./img/dogcat_2/dog/dog.12497.jpg', 1),
     ('./img/dogcat_2/dog/dog.12498.jpg', 1),
     ('./img/dogcat_2/dog/dog.12499.jpg', 1)]
    
    dataset[0][1]  # 第一维是第几张图,第二维为 1 返回 label
    
    0
    
    # 没有任何的 transform,多以返回的还是 PIL Image 对象
    dataset[0][0]  # 为 0 返回图片数据,返回的 Image 对象如下图所示
    

    # 加上 transform
    normalize = T.Normalize(mean=[0.4, 0.4, 0.4], std=[0.2, 0.2, 0.2])
    transform = T.Compose([
        T.RandomResizedCrop(224),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize,
    ])
    dataset = ImageFolder('img/dogcat_2', transform=transform)
    dataset[0][0].size()  # 深度学习图片数据一般保存成 C*H*W,即 通道数*图片高*图片宽
    
    torch.Size([3, 224, 224])
    
    to_img = T.ToPILImage()
    # 0.2 和 0.4 是标准差和均值的近似
    to_img(dataset[0][0]*0.2+0.4) # 程序输出如下图所示
    

    五、DataLoader 的使用——批加载数据

    Dataset 只负责抽象数据,并且一次调用 __getitem__ 只返回一个样本。

    在训练神经网络的时候,是对一个 batch 的数据进行操作,同时还需要对数据进行 shuffle 和并行加速等,为此,torch 提供了 DataLoader 去实现这些功能。

    DataLoader 的函数定义如下:

    DataLoader(dataset,
               batch_size=1,
               shuffle=False,
               sampler=None,
               num_workers=0,
               collate_fn=default_collate,
               pin_memory=False,
               drop_last=False)
    
    • dataset:加载的数据集(Dataset 对象)
    • batch_size:batch size(批大小)
    • shuffle:是否把数据打乱
    • sampler:样本抽样,后面会详细解释
    • num_workers:使用多进程加载的进程数,0 表示不使用多进程
    • collate_fn:如何把多个数据拼接成一个 batch,一般使用默认的方式就可以了
    • pin_memory:是否将数据保存在 pin memory 区,pin memory 中的数据转到 GPU 中速度会快一些
    • drop_last:dataset 中的数据个数可能不是 batch_size 的整数倍,drop_last 为 True,会把多出来不足一个 Batch 的数据丢弃
    from torch.utils.data import DataLoader
    
    dataloader = DataLoader(dataset,
                            batch_size=3,
                            shuffle=True,
                            num_workers=0,
                            drop_last=False)
    
    dataiter = iter(dataloader)  # dataloader是一个可迭代对象,通过 iter 把 dataloader 变成一个迭代器
    imgs, labels = next(dataiter)
    imgs.size()  # batch_size,channel,height,weight
    
    torch.Size([3, 3, 224, 224])
    

    dataloader 是一个可迭代的对象,因此可以像使用迭代器一样使用它。迭代器如果你忘记了是啥,可以看这篇文章:迭代器

    # 迭代器的两种使用方法
    # 第一种直接获取所有数据,数据量大不建议使用
    for batch_datas, batch_labels in dataloader:
        train()
    
    # 第二种只生成一个迭代器,用一个取一个数据
    dataiter = iter(dataloader)
    imgs, labels = next(dataiter)
    

    六、处理损坏图片

    class NewDogCat(DogCat):
        def __getitem__(self, index):
            try:
                # 调用父类的获取函数,相当于 DogCat.__getitem__(self,index)
                return super(NewDogCat, self).__getitem__(index)
            except:
                return None, None  # 获取异常的对象返回 None
    
    
    from torch.utils.data.dataloader import default_collate  # 导入默认的拼接方式
    
    
    def my_collate_fn(batch):
        """
        batch 中每个元素形如(data,label)
        """
        batch = list(filter(lambda x: x[0] is not None, batch))  # 过滤为 None 的数据
        return default_collate(batch)  # 用默认方式拼接过滤后的 batch 数据
    
    
    dataset = NewDogCat('img/dogcat_wrong/', transforms=transform)
    dataset[6]
    
    (None, None)
    
    dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=0)
    for batch_datas, batch_labels in dataloader:
        print(batch_datas.size(), batch_labels.size())
    
    torch.Size([2, 3, 224, 224]) torch.Size([2])
    torch.Size([2, 3, 224, 224]) torch.Size([2])
    torch.Size([2, 3, 224, 224]) torch.Size([2])
    torch.Size([1, 3, 224, 224]) torch.Size([1])
    torch.Size([1, 3, 224, 224]) torch.Size([1])
    

    通过查看上面的打印结果,可以看到第 4 个 batch_size 为 1,这是因为其中有一张图片损坏,而最后一个 batch_size 也是 1,是因为总共有 9 张图片,无法整除 2,因此最后一个 batch 的数据会少于 batch_size,可以通过指定 drop_last=True 丢弃最后一个样本数目不足 batch_size 的 batch。

    除了上述所说的方法,对于损坏或数据集加载异常等情况,还可以通过其他方法解决,例如遇到异常图片,就可以随机选择另外一张图片代替,则 batch_size 就不会小于规定的 batch_size。

    class NewDogCat(DogCat):
        def __getitem__(self, index):
            try:
                return super(NewDogCat, self).__getitem__(index)
            except:
                new_index = random.randint(0, len(self) - 1)
                return self[new_index]
    

    上述所说的方法看起来很好,但是如果我们换个角度去想,我为什么要让文件夹里面有一张异常的图片呢?因此为了防止图片异常,更应该对数据进行彻底清洗。

    DataLoader 为了实现多进程加速,它封装了 Python 的标准库 multiprocessing,因此在 Dataset 和 DataLoader 使用时有以下两个建议:

    1. 高负载的操作放在 __getitem__中,如加载图片等
    2. dataset 中应该尽量只包含只读对象,避免修改任何可变对象

    第一点是因为多进程会并行地调用 __getitem__ 函数,把负载高的放在 __getitem__ 函数中能够实现并行加速。

    第二点是因为 dataloader 使用多进程加载,如果在 Dataset 中使用了可变对象,可能会有意想不到的冲突。在多线程/多进程中,修改一个可变对象需要加锁,但是 dataloader 的设计让它很难加锁,因此最好避免在 dataset 中修改可变对象。

    下面就是一个不好的例子,在多进程中处理的 self.num 可能和预期不符,这种问题不会报错,所以很难发现。如果真的一定要修改可变对象,可以使用 Python 标准库 Queue 中的相关数据结构。

    class BadDataset(data.Dataset):
        def __init__(self):
            self.datas = range(10)
            self.num = 0  # 取数据的次数
    
        def __getitem__(self, index):
            self.num += 1
            return self.datas[index]
    

    使用 Python 的 multiprocessing 库的另一个问题就是,在使用多进程时,如果主程序异常终止,相应的数据加载进程可能无法正常退出。这个时候你可能会发现程序已经退出了,但是 GPU 显存和内存仍然被占用着,这个时候就需要手动强行终止进程。

    七、数据采样

    torch 中还单独提供了一个 sampler 模块,用来进行数据采样。常用的有随机采样器 RandomSampler,当 dataloader 的 shuffle 参数为 True 时,系统就会自动调用这个采样器,进而打乱数据。

    默认的采样器是 SequentialSampler,它会按顺序一个一个进行采样。

    在这里介绍另外一个很有用的采样方法 WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可以用它进行重采样。

    构建 WeightedRandomSampler 时需要提供3个参数:

    • 每个样本的权重weights
    • 共选取的样本总数 num_samples
    • 可选参数 replacement,指定是否可以重复选取一个样本,默认为 True,也就是说允许一个 epoch 中重复采样一个数据。如果设置为 False,则当某一类样本被全部选取结束后,它的样本还没有达到 num_samples 时,sampler 将不会再从该类中选择数据,此时可能会导致 weights 参数失效

    注:权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。

    dataset = DogCat('./img/dogcat/', transforms=transform)
    # 狗的图片被取出的概率是猫的概率的两倍
    # 两类图片被取出的概率和 weights 的绝对大小无关,只和比值有关,例如这里的比值为 2:1
    weights = [2 if label == 1 else 1 for data, label in dataset]
    weights
    
    [1, 2, 2, 1, 2, 2, 1, 1]
    
    from torch.utils.data.sampler import WeightedRandomSampler
    
    sampler = WeightedRandomSampler(weights, num_samples=9, replacement=True)
    dataloader = DataLoader(dataset, batch_size=3, sampler=sampler)
    
    for datas, labels in dataloader:
        print(labels.tolist())
    
    [1, 1, 1]
    [1, 0, 0]
    [1, 0, 1]
    

    从上面可以看到猫狗样本的比例约为 1:2,另外一共只有 8 个样本,却返回了 9 个,说明有样本被重复返回,这就是 replacement 参数的左右,下面我们把 replacement 设为 False。

    # 如果 weights 设定为 100:1,则 猫 的被选中的概率几乎为 0
    weights = [100 if label == 1 else 1 for data, label in dataset]
    
    sampler = WeightedRandomSampler(weights, num_samples=9, replacement=True)
    dataloader = DataLoader(dataset, batch_size=3, sampler=sampler)
    
    for datas, labels in dataloader:
        print(labels.tolist())
    
    [1, 1, 1]
    [1, 1, 1]
    [1, 1, 1]
    
    sampler = WeightedRandomSampler(weights, 8, replacement=False)
    dataloader = DataLoader(dataset, batch_size=4, sampler=sampler)
    for datas, labels in dataloader:
        print(labels.tolist())
    
    [1, 1, 1, 1]
    [0, 0, 0, 0]
    

    从上面的代码可以看到,num_samples 等于 dataset 的样本总数,为了不重复选取,sampler 会把每个样本都返回,这样就失去了 weight 参数的意义。

    从上面的例子可以看出 sampler 在样本采样中的作用:如果指定了 sampler,shuffle 将不会再生效,并且 sampler.num_samples 会覆盖 dataset 的实际大小,也就是一个 epoch 返回的图片总数取决于 sampler.num_samples。

  • 相关阅读:
    win7下安装、使用jBuiler2006
    c#:使用using关键字自动释放资源未必一定就会有明显好处
    silverlight:ScrollViewer的各种高度研究
    silverlight:对象拖动的优雅解决方案
    民航货运英文缩写
    "RDLC"报表参数传递及主从报表
    "RDLC报表"速成指南
    打印常识:A4纸张在显示器上应该要多少像素?
    Silverlight:获取ContentTemplate中的命名控件
    Silverlight:双向绑定综合应用多集合的依赖绑定
  • 原文地址:https://www.cnblogs.com/nickchen121/p/14708224.html
Copyright © 2011-2022 走看看