zoukankan      html  css  js  c++  java
  • 阅读笔记 The Impact of Imbalanced Training Data for Convolutional Neural Networks [DegreeProject2015] 数据分析型

    The Impact of Imbalanced Training Data for Convolutional Neural Networks

    Paulina Hensman and David Masko

    摘要

         本论文从实验的角度调研了训练数据的不均衡性对采用CNN解决图像分类问题的性能影响。CIFAR-10数据集包含10个不同类别的60000个图像,用来构建不同类间分布的数据集。例如,一些训练集中包含一个类别的图像数目与其他类别的图像数目比例失衡。用这些训练集分别来训练一个CNN,度量其得到的网络的分类性能。实验结果表明:不均衡的训练数据对CNN的整体性能可能具有严重的负面影响,而均衡的训练数据能产生最好的性能。Oversampling技术在不均衡训练数据上可以将性能提升到均衡数据上的水平,所以它是一种对抗不均衡性的重要技术。

    概况

         在过去的几年里,由于在诸如机器视觉、语音识别及自然语言处理等几个领域获得重大突破,人工神经网络(Artificial Neural Networks)得到广泛的关注。没有任何先验与假设,这些网络采用统计的方法可以近似大量数据中潜在的函数与模式。DNN(Deep Neural Networks)以及CNN(Convolutional Neural Networks)两类特殊的神经网络是常用来解决复杂问题的现代方法。不利的一面是,为了学习到一个令人满意的神经网络,通常需要大量的数据。对于有监督的学习,还需要大量的标注数据。众所周知,标注数据通常是依赖于人工标注获得,因此获取困难。有一些标注好的图像数据是公开可用的,这些数据为研究与应用人员提供了标准资源,便于比较不同分类方法的,用以证明在该领域取得了一些进展。经验上讲,平衡数据集优于非平衡数据集,然而在真实的情况下,可用的数据集通常是不均衡的。如何处理不均衡数据是机器学习中一个很大的挑战。一些方法能够减轻不均衡数据带来的影响,但是并没有系统的研究结果表明DNN与CNN在标准数据集上如何受不均衡数据的影响。

         本文重点研究由于训练数据的类别不均衡带来的CNN分类性能的损失。由此进一步探索:什么类型的分布对性能有损?Oversampling在提升性能方面起多大的作用?具体来讲主要包含以下四个问题:

    (1)训练数据中均衡的类别分别对CNN的重要性有多大?

    (2)CNN的性能如何受训练数据中不同类别分布的影响?

    (3)通过调整训练数据的类别分布能否改善CNN的性能?

    (4)有什么可行的方法来实现这种调整?

    图像分类是判断给定的图像属于哪一类别的过程,直观来讲,就是图像包含了哪些物体。图像分类主要有两种形式:图像级别标注与对象级别标注。图像级别标注是一个二值变量,用来指示一个对象是否出现在图像上,例如,图像上是否有一只猫。对象级别的标注是具体到对象在图像中出现的位置。例如,螺丝刀中心位于(20,25),宽为50像素,高为30像素。本文关注图像级别的标注。

    不均衡数据是指机器学习算法在训练的过程中所采用的数据在不同类别上的分布是不均衡的。由于采用均衡数据学习的算法性能远优于不均衡数据的,所以不均衡数据给分类问题带来了挑战。实际中可用的数据通常是不均衡的。然而,大多数的学习算法假设训练数据是均衡,也同样假设未标注的数据也是类间均衡的。若训练数据的分布于测试集并不相同,这类算法通常会降低性能。进一步来讲,多数算法的目标在于最小化整体的错误率,这会导致训练数据中的小众类由于训练数据少而性能不佳。当小众类非常重要时,这种影响是完全负面的。例如,罕见疾病的诊断。不均衡数据已经得到了广泛的关注,有许多有效的方法可以解决这个问题。

    已有提升不均衡数据上的学习性能的方法大致分为三类: (1)sampling techniques;(2)Cost sensitive techniques;(3)One-class learning。采样技术改变原始的数据集,从而创建均衡数据集。简单的采样技术包括oversampling(从小众类中重复采样直至均衡),undersampling(移除over-represented类别的数据)与其他采样技术。然而有研究表明将oversampling与undersampling结合可能是应对极端不均衡数据的方式。

    • budget-sensitive progressive sampling algorithm

           训练数据数目n

    该采样策略依赖于几个假设:(1)与获取训练数据相比,学习算法的执行代价是可以忽略的,因为在该采样算法中学习算法需要运行多次。当训练数据获取代价高时,这一点是成立的。(2)假设每个类别的获取代价是相同的。这样的话预算数目n与训练实例数是一致的。这个假设大多时候是成立的,但也有例外。如,先前提及的电话数据,获取普通消费者和商业电话的代价是一样的,但是欺诈电话的识别代价是高昂的。

    •  combination of cost-sensitive technique and undersampling
    实验设置

    数据集:选用CIFAR-10,包含10个不同的类别,数据集较小,仅包含60000左右的images(不选择ImageNet的原因),便于做批量的实验,但又不至于任务太简单(如MNIST)

    数据集划分:5000 images per category for training and 1000 for testing

    类别分布:选择11个不同的类别分布,分别考察其分类性能,每种分布其实都是具有代表性的,毕竟10个类别的分布均衡,是很难量化的一个指标,所以这里只是举出几个典型的例子来说明。在本文中,并没有给出class imbalance的一个明确的量化的定义。

    网络结构:use caffe to create and train a CNN

    参数设置:3 convolutional layers and 10 output nodes, trained with learning rate 0.001 for 8 epochs + learning rate 0.0001 for 2 epochs, momentum set to 0.9, weight decay to 0.004

    测试数据:mean results of three runs

    评价指标:the percentage of correct answers for each class,然后再做平均。

    实验结果

    (1)数据越均衡,分类性能越好 

    (2)oversampling可以给imbalance 数据带来性能的提升,数据越不均衡提升越明显。 

  • 相关阅读:
    ViewState与Session [转]
    HTML5和HTML4的主要区别 [转]
    委托 与 Lambda
    ArcGIS 基础4-删除数据
    ArcGIS 基础3-新建数据
    ArcGIS 基础2-编辑数据
    ArcGIS 基础1-打开地图文档并浏览
    成都地铁线路图
    矢量数据库合并工具
    ArcGIS Pro试用下载步骤
  • 原文地址:https://www.cnblogs.com/shuzirank/p/5778566.html
Copyright © 2011-2022 走看看