zoukankan      html  css  js  c++  java
  • gluon模块进行数据加载-Dataset和DataLoader

    在gluon接口中,通过Dataset和DataLoader来对数据集进行循环遍历,并返回batch大小的数据,其中Dataset对象用于数据的收集、加载和变换,而DataLoader对象用于返回batch大小的数据。

    1. 相关模块
    mxnet.gluon.data : 数据加载API
    mxnet.gluon.data.vision : 专门用于计算机视觉的数据集API和处理工具
    2. Dataset介绍
    Dataset对象用于收集数据、加载和变换数据,其中 数据加载是能够通过给定的下标获取对应的样本,数据变换可以对数据进行各种数据增广操作

    原理
    所有Dataset类中,都有以下四个方法:

    getitem(idx): 数据加载,用于返回第idx个样本
    len(): 用于返回数据集的样本的数量
    transform(fn, lazy = True): 数据变换,用于返回对每个样本利用fn函数进行数据变换(增广)后的Dataset
    transform_first(fn, lazy = True): 数据变换,用于返回对每个样本的特征利用fn函数进行数据变换(增广)后的Dataset,而不对label进行数据增广
    用法
    ArrayDataset

    import mxnet as mx
    ## 定义
    mx.random.seed(42) # 固定随机数种子,以便能够复现
    X = mx.random.uniform(shape = (10, 3))
    y = mx.random.uniform(shape = (10, 1))
    dataset = mx.gluon.data.ArrayDataset(X, y) # ArrayDataset不需要从硬盘上加载数据
    ## 使用
    dataset[5]  # 将返回第6个样本的特征和标签,(特征,标签)

    ImageFolderDataset

    import mxnet as mx
    ## 定义
    dataset= gdata.vision.ImageFolderDataset("样本集的根路径", flag=1)
    ## 使用
    dataset[5]  # 将返回第6个样本的特征和标签,(特征,标签)

    ImageRecordDataset

    import mxnet as mx
    ## 定义
    file = '/xxx/train.rec'
     #不需要指定idx文件路径,会从路径中自动拼接处idx的路径,例如此处为/xxx/train.idx
    dataset= gdata.vision.ImageRecordDataset(file) 
    ## 使用
    dataset[5]  # 将返回第6个样本的特征和标签,(特征,标签)

    API中的所有Dataset


    mxnet.gluon.data.Dataset: 抽象的数据集类

    mxnet.gluon.data.ArrayDataset: 组合多个Dataset的数据集类

    mxnet.gluon.data.RecordFileDataset: .rec文件的数据集类

    mxnet.gluon.data.vison.MNIST: MNIST数据集的Dataset

    mxnet.gluon.data.vison.FashionMNIST: FashionMNIST数据集的Dataset

    mxnet.gluon.data.vison.CIFAR10: CIFAR10数据集的Dataset

    mxnet.gluon.data.vison.CIFAR100: CIFAR100数据集的Dataset

    mxnet.gluon.data.vison.ImageRecordDataset: 含有图片的.rec文件的Dataset

    mxnet.gluon.data.vison.ImageFolderDataset: 存储图片在文件夹结构的Dataset

    说明: mxnet和numpy的array可以直接作为Dataset

    3. DataLoader介绍
    加载Dataset,迭代时返回batch大小的样本
    可以方便的并行地加载数据
    使用示例如下:

    from multiprocessing import cpu_count
    CPU_COUNT = cpu_count()
    data_loader = mx.gluon.data.DataLoader(dataset, batch_size = 5, num_workers = CPU_COUNT)
    for X, y in data_loader:
        print X.shape, y.shape

    4. transforms模块介绍
    在gloun的data接口中,有可以使用的数据增广的模块(mxnet.gluon.data.vision.tranforms)。在transforms模块中定义了很多数据变换的layer(为Block的子类),变换layer的输入为样本,输出为变换后的样本。

    用法示例

    from mxnet.gluon import data as gdata
    train_ds = gdata.vision.ImageFolderDataset("样本集的根路径", flag=1)
    print train_ds[0] #变换之前的数据
    ## 数据变换定义
    transform_train = gdata.vision.transforms.Compose([  # Compose将这些变换按照顺序连接起来
            # 将图片放大成高和宽各为 40 像素的正方形。
            gdata.vision.transforms.Resize(40),
            # 随机对高和宽各为 40 像素的正方形图片裁剪出面积为原图片面积 0.64 到 1 倍之间的小正方
            # 形,再放缩为高和宽各为 32 像素的正方形。
            gdata.vision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),
                                                      ratio=(1.0, 1.0)),
            # 随机左右翻转图片。
            gdata.vision.transforms.RandomFlipLeftRight(),
            # 将图片像素值按比例缩小到 0 和 1 之间,并将数据格式从“高 * 宽 * 通道”改为“通道 * 高 * 宽”。
            gdata.vision.transforms.ToTensor(),
            # 对图片的每个通道做标准化。
            gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],
                                              [0.2023, 0.1994, 0.2010])
        ])
    
    train_ds_transformed = train_ds.transform_first(train_ds )
    print  train_ds_transformed[0] #变换之后的数据

    重要的变换

    • Cast: 变换数据类型
    • ToTensor: 将图像数组由“高 * 宽 * 通道”改为 “通道 * 高 * 宽”
    • Normalize: 对图片(shape为通道 * 高 * 宽)每个通道上的每个像素按照均值和方差标准化
    • RandomResizedCrop: 首先按照一定的比例随机裁剪图像,然后再对图像变换高和宽
    • Resize: 将图像变换高和宽
    • RandomFlipLeftRight: 随机左右翻转

    例子

    # coding: utf-8
    from mxnet.gluon import data as gdata
    import multiprocessing
    import os
    
    def get_cifar10(root_dir,  batch_size, num_workers =  1):
        train_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'train'), flag=1)
        valid_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'valid'), flag=1)
        train_valid_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'train_valid'), flag=1)
        test_ds = gdata.vision.ImageFolderDataset(os.path.join(root_dir, 'test'), flag=1)
    
        transform_train = gdata.vision.transforms.Compose([
            # 将图片放大成高和宽各为 40 像素的正方形。
            gdata.vision.transforms.Resize(40),
            # 随机对高和宽各为 40 像素的正方形图片裁剪出面积为原图片面积 0.64 到 1 倍之间的小正方
            # 形,再放缩为高和宽各为 32 像素的正方形。
            gdata.vision.transforms.RandomResizedCrop(32, scale=(0.64, 1.0),
                                                      ratio=(1.0, 1.0)),
            # 随机左右翻转图片。
            gdata.vision.transforms.RandomFlipLeftRight(),
            # 将图片像素值按比例缩小到 0 和 1 之间,并将数据格式从“高 * 宽 * 通道”改为
            # “通道 * 高 * 宽”。
            gdata.vision.transforms.ToTensor(),
            # 对图片的每个通道做标准化。
            gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],
                                              [0.2023, 0.1994, 0.2010])
        ])
    
        # 测试时,无需对图像做标准化以外的增强数据处理。
        transform_test = gdata.vision.transforms.Compose([
            gdata.vision.transforms.ToTensor(),
            gdata.vision.transforms.Normalize([0.4914, 0.4822, 0.4465],
                                              [0.2023, 0.1994, 0.2010])
        ])
        train_ds = train_ds.transform_first(transform_train)
        valid_ds = valid_ds.transform_first(transform_test)
        train_valid_ds = train_valid_ds.transform_first(transform_train)
        test_ds = test_ds.transform_first(transform_test)
        train_data = gdata.DataLoader(train_ds, batch_size, shuffle=True, last_batch='keep',num_workers = num_workers)
        valid_data = gdata.DataLoader(valid_ds, batch_size, shuffle=False, last_batch='keep', num_workers = num_workers)
        train_valid_data = gdata.DataLoader(train_valid_ds, batch_size, shuffle=True, last_batch='keep', num_workers=num_workers)
        test_data = gdata.DataLoader(test_ds, batch_size, shuffle=False, last_batch='keep', num_workers=num_workers)
        return train_data, valid_data, train_valid_data, test_data
    
    
    if __name__ == '__main__':
        batch_size = 256
        root_dir =  '/home/face/common/samples/cifar-10/train_valid_test'
    
        train_data, valid_data, train_valid_data, test_data = get_cifar10(root_dir, batch_size)
    
        for batch in train_data:
            data, label = batch
            print data.shape, label
  • 相关阅读:
    程序员励志语录
    javaEE的十一种技术
    gui内函数调用顺序
    m文件中函数的执行顺序
    VC++与Matlab混合编程之引擎操作详解(6)数据类型mxArray的操作
    GUI(2)
    时间管理
    Matlab GUI界面
    matlab GUI(2)
    MATLAB GUI平台
  • 原文地址:https://www.cnblogs.com/psztswcbyy/p/11673044.html
Copyright © 2011-2022 走看看