zoukankan      html  css  js  c++  java
  • gluon模块进行数据加载-Dataset和DataLoader

    在gluon接口中,通过Dataset和DataLoader来对数据集进行循环遍历,并返回batch大小的数据,其中Dataset对象用于数据的收集、加载和变换,而DataLoader对象用于返回batch大小的数据。

    1. 相关模块
    mxnet.gluon.data : 数据加载API
    mxnet.gluon.data.vision : 专门用于计算机视觉的数据集API和处理工具
    2. Dataset介绍
    Dataset对象用于收集数据、加载和变换数据,其中 数据加载是能够通过给定的下标获取对应的样本,数据变换可以对数据进行各种数据增广操作

    原理
    所有Dataset类中,都有以下四个方法:

    getitem(idx): 数据加载,用于返回第idx个样本
    len(): 用于返回数据集的样本的数量
    transform(fn, lazy = True): 数据变换,用于返回对每个样本利用fn函数进行数据变换(增广)后的Dataset
    transform_first(fn, lazy = True): 数据变换,用于返回对每个样本的特征利用fn函数进行数据变换(增广)后的Dataset,而不对label进行数据增广
    用法
    ArrayDataset

    import mxnet as mx
    ## 定义
    mx.random.seed(42) # 固定随机数种子,以便能够复现
    X = mx.random.uniform(shape = (10, 3))
    y = mx.random.uniform(shape = (10, 1))
    dataset = mx.gluon.data.ArrayDataset(X, y) # ArrayDataset不需要从硬盘上加载数据
    ## 使用
    dataset[5]  # 将返回第6个样本的特征和标签,(特征,标签)

    ImageFolderDataset

    import mxnet as mx
    ## 定义
    dataset= gdata.vision.ImageFolderDataset("样本集的根路径", flag=1)
    ## 使用
    dataset[5]  # 将返回第6个样本的特征和标签,(特征,标签)

    ImageRecordDataset

    import mxnet as mx
    ## 定义
    file = '/xxx/train.rec'
     #不需要指定idx文件路径,会从路径中自动拼接处idx的路径,例如此处为/xxx/train.idx
    dataset= gdata.vision.ImageRecordDataset(file) 
    ## 使用
    dataset[5]  # 将返回第6个样本的特征和标签,(特征,标签)

    API中的所有Dataset


    mxnet.gluon.data.Dataset: 抽象的数据集类

    mxnet.gluon.data.ArrayDataset: 组合多个Dataset的数据集类

    mxnet.gluon.data.RecordFileDataset: .rec文件的数据集类

    mxnet.gluon.data.vison.MNIST: MNIST数据集的Dataset

    mxnet.gluon.data.vison.FashionMNIST: FashionMNIST数据集的Dataset

    mxnet.gluon.data.vison.CIFAR10: CIFAR10数据集的Dataset

    mxnet.gluon.data.vison.CIFAR100: CIFAR100数据集的Dataset

    mxnet.gluon.data.vison.ImageRecordDataset: 含有图片的.rec文件的Dataset

    mxnet.gluon.data.vison.ImageFolderDataset: 存储图片在文件夹结构的Dataset

    说明: mxnet和numpy的array可以直接作为Dataset

    3. DataLoader介绍
    加载Dataset,迭代时返回batch大小的样本
    可以方便的并行地加载数据
    使用示例如下:

    from multiprocessing import cpu_count
    CPU_COUNT = cpu_count()
    data_loader = mx.gluon.data.DataLoader(dataset, batch_size = 5, num_workers = CPU_COUNT)
    for X, y in data_loader:
        print X.shape, y.shape

    4. transforms模块介绍
    在gloun的data接口中,有可以使用的数据增广的模块(mxnet.gluon.data.vision.tranforms)。在transforms模块中定义了很多数据变换的layer(为Block的子类),变换layer的输入为样本,输出为变换后的样本。

    用法示例

    from mxnet.gluon import data as gdata
    train_ds = gdata.vision.ImageFolderDataset("样本集的根路径", flag=1)
    print train_ds[0] #变换之前的数据
    ## 数据变换定义
    transform_train = gdata.vision.transforms.Compose([  # Compose将这些变换按照顺序连接起来
            # 将图片放大成高和宽各为 40 像素的正方形。
            gdata.vision.transforms.Resize(40),
            # 随机对高和宽各为 40 像素的正方形图片裁剪出面积为原图片面积 0.64 到 1 倍之间的小正方
            # 形,再放缩为高和宽各为 32 像素的正方形。
            gdata.vision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),
                                                      ratio=(1.0, 1.0)),
            # 随机左右翻转图片。
            gdata.vision.transforms.RandomFlipLeftRight(),
            # 将图片像素值按比例缩小到 0 和 1 之间,并将数据格式从“高 * 宽 * 通道”改为“通道 * 高 * 宽”。
            gdata.vision.transforms.ToTensor(),
            # 对图片的每个通道做标准化。
            gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],
                                              [0.2023, 0.1994, 0.2010])
        ])
    
    train_ds_transformed = train_ds.transform_first(train_ds )
    print  train_ds_transformed[0] #变换之后的数据

    重要的变换

    • Cast: 变换数据类型
    • ToTensor: 将图像数组由“高 * 宽 * 通道”改为 “通道 * 高 * 宽”
    • Normalize: 对图片(shape为通道 * 高 * 宽)每个通道上的每个像素按照均值和方差标准化
    • RandomResizedCrop: 首先按照一定的比例随机裁剪图像,然后再对图像变换高和宽
    • Resize: 将图像变换高和宽
    • RandomFlipLeftRight: 随机左右翻转

    例子

    # coding: utf-8
    from mxnet.gluon import data as gdata
    import multiprocessing
    import os
    
    def get_cifar10(root_dir,  batch_size, num_workers =  1):
        train_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'train'), flag=1)
        valid_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'valid'), flag=1)
        train_valid_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'train_valid'), flag=1)
        test_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'test'), flag=1)
    
        transform_train = gdata.vision.transforms.Compose([
            # 将图片放大成高和宽各为 40 像素的正方形。
            gdata.vision.transforms.Resize(40),
            # 随机对高和宽各为 40 像素的正方形图片裁剪出面积为原图片面积 0.64 到 1 倍之间的小正方
            # 形,再放缩为高和宽各为 32 像素的正方形。
            gdata.vision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),
                                                      ratio=(1.0, 1.0)),
            # 随机左右翻转图片。
            gdata.vision.transforms.RandomFlipLeftRight(),
            # 将图片像素值按比例缩小到 0 和 1 之间,并将数据格式从“高 * 宽 * 通道”改为
            # “通道 * 高 * 宽”。
            gdata.vision.transforms.ToTensor(),
            # 对图片的每个通道做标准化。
            gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],
                                              [0.2023, 0.1994, 0.2010])
        ])
    
        # 测试时,无需对图像做标准化以外的增强数据处理。
        transform_test = gdata.vision.transforms.Compose([
            gdata.vision.transforms.ToTensor(),
            gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],
                                              [0.2023, 0.1994, 0.2010])
        ])
        train_ds = train_ds.transform_first(transform_train)
        valid_ds = valid_ds.transform_first(transform_test)
        train_valid_ds = train_valid_ds.transform_first(transform_train)
        test_ds = test_ds.transform_first(transform_test)
        train_data = gdata.DataLoader(train_ds, batch_size, shuffle=True, last_batch='keep',num_workers = num_workers)
        valid_data = gdata.DataLoader(valid_ds, batch_size, shuffle=False, last_batch='keep', num_workers = num_workers)
        train_valid_data = gdata.DataLoader(train_valid_ds, batch_size, shuffle=True, last_batch='keep', num_workers=num_workers)
        test_data = gdata.DataLoader(test_ds, batch_size, shuffle=False, last_batch='keep', num_workers=num_workers)
        return train_data, valid_data, train_valid_data, test_data
    
    
    if __name__ == '__main__':
        batch_size = 256
        root_dir =  '/home/face/common/samples/cifar-10/train_valid_test'
    
        train_data, valid_data, train_valid_data, test_data = get_cifar10(root_dir, batch_size)
    
        for batch in train_data:
            data, label = batch
            print data.shape, label
  • 相关阅读:
    一些你可能用到的代码
    iOS 键盘下去的方法
    iOS设计模式汇总
    随笔
    Spring cloud config 分布式配置中心 (三) 总结
    Spring cloud config 分布式配置中心(二) 客户端
    Spring cloud config 分布式配置中心(一) 服务端
    jdbcUrl is required with driverClassName spring boot 2.0版本
    JpaRepository接口找不到 spring boot 项目
    解决IntelliJ “Initialization failed for 'https://start.spring.io'
  • 原文地址:https://www.cnblogs.com/psztswcbyy/p/11673044.html
Copyright © 2011-2022 走看看