zoukankan      html  css  js  c++  java
  • 手写数字识别[paddle框架]:1.数据处理

    |数据处理

    上一节,我们通过调用飞桨提供的API(paddle.dataset.mnist)加载MNIST数据集。但在工业实践中,我们面临的任务和数据环境千差万别,通常需要自己编写适合当前任务的数据处理程序,一般涉及如下五个环节:

    • 读入数据
    • 划分数据集
    • 生成批次数据
    • 训练样本集乱序
    • 校验数据有效性

    1.读入数据

    实际应用中,数据保存的格式多种多样,需要按一定规则读入数据data,这部分没有通用方法。

    2.划分数据及

    通常会把data划分为三个子集:train_set(训练集),val_set(验证集),test_set(测试集)

    • train_set:用于确定模型参数
    • val_set:用于调节模型超参数(如多个网络结构、正则化权重的最优选择)。
    • test_set:用于估计应用效果(没有在模型中应用过的数据,更贴近模型在正式场景应用的效果)

    3.训练样本乱序

    常见操作:先将样本按顺序进行编号,建立ID集合index_list。然后将index_list乱序,最后按乱序后的顺序读取数据。

    说明: 通过大量实验发现,模型对最后出现的数据印象更加深刻。训练数据导入后,越接近模型训练结束,最后几个批次数据对模型参数的影响越大。为了避免模型记忆影响训练效果,需要进行样本乱序操作。

    4.生成批次数据

    说明:

    在实际问题中,数据集往往非常大,如果每次都使用全量数据进行计算,效率非常低,通俗地说就是“杀鸡焉用牛刀”。由于参数每次只沿着梯度反方向更新一点点,因此方向并不需要那么精确。一个合理的解决方案是每次从总的数据集中随机抽取出小部分数据来代表整体,基于这部分数据计算梯度和损失来更新参数,这种方法被称作随机梯度下降法(Stochastic Gradient Descent,SGD),核心概念如下:

    • mini-batch:每次迭代时抽取出来的一批数据被称为一个mini-batch。
    • batch_size:一个mini-batch所包含的样本数目称为batch_size。
    • epoch:当程序迭代的时候,按mini-batch逐渐抽取出样本,当把整个数据集都遍历到了的时候,则完成了一轮训练,也叫一个epoch。启动训练时,可以将训练的轮数num_epoches和batch_size作为参数传入。

    通常操作:

    生成批次数据:先设置合理的batch_size,再讲数据转变为符合模型输入要求的np.array格式返回。同时,再返回数据时使用python生成器yield,以减少内存占用。

    在执行如上两个操作之前,需要先将数据处理代码封装为load_data函数,方便后续调用。

    5.校验数据有效性

    在实际应用中,原始数据可能存在标注不准确、数据杂乱或格式不统一等情况。因此在完成数据处理流程后,还需要进行数据校验,一般有两种方式:

    • 机器校验:加入一些校验和清理数据的操作。
    • 人工校验:先打印数据输出结果,观察是否是设置的格式。再从训练的结果验证数据处理和读取的有效性。

    机器校验

    如下代码所示,如果数据集中的图片数量和标签数量不等,说明数据逻辑存在问题,可使用assert语句校验图像数量和标签数据是否一致。

     imgs_length = len(imgs)
    
        assert len(imgs) == len(labels), 
        "length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(labels))
    

    人工校验

    人工校验是指打印数据输出结果,观察是否是预期的格式。实现数据处理和加载函数后,我们可以调用它读取一次数据,观察数据的shape和类型是否与函数中设置的一致。

    # 声明数据读取函数,从训练集中读取数据
    train_loader = data_generator
    # 以迭代的形式读取数据
    for batch_id, data in enumerate(train_loader()):
        image_data, label_data = data
        if batch_id == 0:
            # 打印数据shape和类型
            print("打印第一个batch数据的维度,以及数据的类型:")
            print("图像维度: {}, 标签维度: {}, 图像数据类型: {}, 标签数据类型: {}".format(image_data.shape, label_data.shape, type(image_data), type(label_data)))
        break
    

    6.封装数据读取与处理函数

    上文,我们从读取数据、划分数据集、到打乱训练数据、构建数据读取器以及数据数据校验,完成了一整套一般性的数据处理流程,下面将这些步骤放在一个函数中实现,方便在神经网络训练时直接调用。

    def load_data(mode='train'):
        datafile = './work/mnist.json.gz'
        print('loading mnist dataset from {} ......'.format(datafile))
        # 加载json数据文件
        data = json.load(gzip.open(datafile))
        print('mnist dataset load done')
        # mode_string = ['train', 'valid', 'eval']
        # mode_dict = dict([(s,idx) for idx,s in enumerate(mode_string)])
       
        # 读取到的数据区分训练集,验证集,测试集
        train_set, val_set, eval_set = data
        if mode=='train':
            # 获得训练数据集
            imgs, labels = train_set[0], train_set[1]
        elif mode=='valid':
            # 获得验证数据集
            imgs, labels = val_set[0], val_set[1]
        elif mode=='eval':
            # 获得测试数据集
            imgs, labels = eval_set[0], eval_set[1]
        else:
            raise Exception("mode can only be one of ['train', 'valid', 'eval']")
        # if not mode in mode_string:
        #     raise Exception("mode can only be one of ['train', 'valid', 'eval']")
        # else:
        #     imgs, labels = data[mode_dict[mode]][0], data[mode_dict[mode]][1]
        print("训练数据集数量: ", len(imgs))
        
        # 校验数据
        imgs_length = len(imgs)
    
        assert len(imgs) == len(labels), 
              "length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(label))
        
        # 获得数据集长度
        imgs_length = len(imgs)
        
        # 定义数据集每个数据的序号,根据序号读取数据
        index_list = list(range(imgs_length))
        # 读入数据时用到的批次大小
        BATCHSIZE = 100
        
        # 定义数据生成器
        def data_generator():
            if mode == 'train':
                # 训练模式下打乱数据
                random.shuffle(index_list)
            imgs_list = []
            labels_list = []
            for i in index_list:
                # 将数据处理成希望的格式,比如类型为float32,shape为[1, 28, 28]
                img = np.reshape(imgs[i], [1, 28, 28]).astype('float32')
                label = np.reshape(labels[i], [1]).astype('float32')
                imgs_list.append(img) 
                labels_list.append(label)
                if len(imgs_list) == BATCHSIZE:
                    # 获得一个batchsize的数据,并返回
                    yield np.array(imgs_list), np.array(labels_list)
                    # 清空数据读取列表
                    imgs_list = []
                    labels_list = []
        
            # 如果剩余数据的数目小于BATCHSIZE,
            # 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch
            if len(imgs_list) > 0:
                yield np.array(imgs_list), np.array(labels_list)
        return data_generator
    
  • 相关阅读:
    centos7下部署nginx+supervisor+netcore2.1服务器环境
    centos6.1配置nodejs运行环境
    centos下远程访问redis端口配置
    如何成为一名合格的软件测试师
    Maven之安装及构建简单项目 掠影
    JAVA语言单元测试框架——JUnit浅析
    软件测试 之 白盒测试 掠影
    软件测试 之 黑盒测试 掠影
    以一个闰年检测程序为例的非法字符异常输入检测
    学习心得——测试框架浅析
  • 原文地址:https://www.cnblogs.com/Biiigwang/p/13811650.html
Copyright © 2011-2022 走看看