zoukankan      html  css  js  c++  java
  • 如何解决回归任务数据不均衡的问题?

    摘要:现有的处理不平衡数据/长尾分布的方法绝大多数都是针对分类问题,而回归问题中出现的数据不均衡问题确极少被研究。

    本文分享自华为云社区《如何解决回归任务数据不均衡的问题?》,原文作者:PG13。

    现有的处理不平衡数据/长尾分布的方法绝大多数都是针对分类问题,而回归问题中出现的数据不均衡问题确极少被研究。但是,现实很多的工业预测场景都是需要解决回归的问题,也就是涉及到连续的,甚至是无限多的目标值,如何解决回归问题中出现的数据不均衡问题呢?ICML2021一篇被接收为Long oral presentation的论文:Delving into Deep Imbalanced Regression,推广了传统不均衡分类问题的范式,将数据不平衡问题从离散值域推广到了连续值域,并提出了两种解决深度不均衡回归问题的方法。

    主要的贡献是三个方面:1)提出了一个深度不均衡回归(Deep Imbalanced Regression, DIR)任务,定义为从具有连续目标的不平衡数据中学习,并能泛化到整个目标范围;2)提出了两种解决DIR的新方法,标签分布平滑(label distribution smoothing, LDS)和特征分布平滑(feature distribution smoothing, FDS),来解决具有连续目标的不平衡数据的学习问题;3)建立了5个新的DIR数据集,包括了CV、NLP、healthcare上的不平衡回归任务,致力于帮助未来在不平衡数据上的研究。

    数据不平衡问题背景

    现实世界的数据通常不会每个类别都具有理想的均匀分布,而是呈现出长尾的偏斜分布,其中某些目标值的观测值明显较少,这对于深度学习模型有较大的挑战。传统的解决办法可以分为基于数据基于模型两种:基于数据的解决方案无非对少数群体进行过采样和对多数群体进行下采样,比如SMOTE算法;基于模型的解决方案包括对损失函数的重加权(re-weighting)或利用相关的学习技巧,如迁移学习、元学习、两阶段训练等。

    但是现有的数据不平衡解决方案,主要是针对具有categorical index的目标值,也就是离散的类别标签数据。其目标值属于不同的类别,并且具有严格的硬边界,不同类别之间没有重叠。现实世界很多的预测场景可能涉及到连续目标值的标签数据。比如,根据人脸视觉图片预测年龄,年龄便是一个连续的目标值,并且在目标范围内可能会高度失衡。在工业领域中,也会发生类似的问题,比如在水泥领域,水泥熟料的质量,一般都是连续的目标值;在配煤领域,焦炭的热强指标也是连续的目标值。这些应用中需要预测的目标变量往往存在许多稀有和极端值。在连续域的不平衡问题在线性模型和深度模型中都是存在的,在深度模型中甚至更为严重,这是因为深度学习模型的预测往往都是over-confident的,会导致这种不平衡问题被严重的放大。

    因此,这篇文章定义了深度不平衡回归问题(DIR),即从具有连续目标值的不平衡数据中学习,同时需要处理某些目标区域的潜在确实数据,并使最终模型能够泛化到整个支持所有目标值的范围上。

    https://bbs-img.huaweicloud.com/blogs/img/images_162328840109677.png

    不平衡回归问题的挑战

    解决DIR问题的三个挑战如下:

    1.  对于连续的目标值(标签),不同目标值之间的硬边界不再存在,无法直接采用不平衡分类的处理方法。
    2.  连续标签本质上说明在不同的目标值之间的距离是有意义的。这些目标值直接告诉了哪些数据之间相隔更近,指导我们该如何理解这个连续区间上的数据不均衡的程度。
    3.  对于DIR,某些目标值可能根本没有数据,这为对目标值做extrapolation和interpolation提供了需求。

    解决方法一:标签分布平滑(LDS)

    首先通过一个例子展示一下当数据出现不均衡的时候,分类和回归问题之间的区别。作者在两个不同的数据集:(1)CIFAR-100,一个100类的图像分类数据集;(2)IMDB-WIKI,一个用于根据人像估算年龄(回归)的图像数据集,进行了比较。通过采样处理来模拟数据不平衡,保证两个数据集具有完全相同的标签密度分布,如下图所示:

    https://bbs-img.huaweicloud.com/blogs/img/images_162328846042796.png

    然后,分别在两个数据集上训练一个ResNet-50模型,并画出它们的测试误差的分布。从图中可以看出,在不平衡的分类数据集CIFAR-100上,测试误差的分布与标签密度的分布是高度负相关的,这很好理解,因为拥有更多样本的类别更容易学好。但是,连续标签空间的IMDB-WIKI的测试误差分布更加平滑,且不再与标签密度分布很好地相关。这说明了对于连续标签,其经验标签密度并不能准确地反映模型所看到的不均衡。这是因为相临标签的数据样本之间是相关的,相互依赖的。

    标签分布平滑:基于这些发现,作者提出了一种在统计学习领域中的核密度估计(LDS)方法,给定连续的经验标签密度分布,LDS使用了一个对称核函数k,用经验密度分布与之卷积,得到一个kernel-smoothed的有效标签密度分布,用来直观体现临近标签的数据样本具有的信息重叠问题,通过LDS计算出的有效标签密度分布结果与误差分布的相关性明显增强。有了LDS估计出的有效标签密度,就可以用解决类别不平衡问题的方法,直接应用于解决DIR问题。比如,最简单地一种make sence方式是利用重加权的方法,通过将损失函数乘以每个目标值的LDS估计标签密度的倒数来对其进行加权。

    https://bbs-img.huaweicloud.com/blogs/img/images_162328850124979.png

    解决方法二:特征分布平滑(FDS)

    如果模型预测正常且数据是均衡的,那么label相近的samples,它们对应的feature的统计信息应该也是彼此接近的。这里作者也举了一个实例验证了这个直觉。作者同样使用对IMDB-WIKI上训练的ResNet-50模型。主要focus在模型学习到的特征空间,不是标签空间。我们关注的最小年龄差是1岁,因此我们将标签空间分为了等间隔的区间,将具有相同目标区间的要素分到同一组。然后,针对每个区间中的数据计算其相应的特征统计量(均值、方差)。特征的统计量之间的相似性可视化为如下图:

    https://bbs-img.huaweicloud.com/blogs/img/images_162328853651222.png
    红色区间代表anchor区间,计算这个anchor label与其他所有label的特征统计量(即均值、方差)的余弦相似度。此外,不同颜色区域(紫色,黄色,粉红色)表示不同的数据密度。从图中可以得到两个结论:

    1.  anchor label和其临近的区间的特征统计量是高度相似的。而anchor label = 30 刚好是在训练数据量非常多的区域。这说明了,当有足够多的数据时,特征的统计量在临近点是相似的。
    2.  此外,在数据量很少的区域,如0-6岁的年龄范围,与30岁年龄段的特征统计量高度相似。这种不合理的相似性是由于数据不均衡造成的。因为,0-6岁的数据很少,该范围的特征会从具有最大数据量的范围继承其先验。

    特征分布平滑:受到这些启发,作者提出了特征分布平滑(FDS)。FDS是对特征空间进行分布的平滑,本质上是在临近的区间之间传递特征的统计信息。此过程的主要作用是去校准特征分布的潜在的有偏差的估计,尤其是对那些样本很少的目标值而言。

    https://bbs-img.huaweicloud.com/blogs/img/images_162328857000880.png
    具体来说,有一个模型,f代表一个encoder将输入数据映射到隐层的特征,g作为一个predictor来输出连续的预测目标值。FDS会首先估计每个区间特征的统计信息。这里用特征的协方差代替方差,来反映特征z内部元素之间的关系。给定特征统计量,再次使用对称核函数k来smooth特征均值和协方差的分布,这样可以拿到统计信息的平滑版本。利用估计和平滑统计量,遵循标准的whitening and re-coloring过程来校准每个输入样本的特征表示。那么整个FDS过程可以通过在最终特征图之后插入一个特征的校准层,实现将FDS集成到深度网络中。最后,在每个epoch采用了动量更新,来获得对训练过程中特征统计信息的一个更稳定和更准确的估计。

    基准DIR数据集

    1.  IMDB-WIKI-DIR(vision, age):基于IMDB-WIKI数据集,从包含人面部的图像来推断估计相应的年龄。
    2.  AgeDB-DIR(vision, age):基于AgeDB数据集,同样是根据输入图像进行年龄估计。
    3.  NYUD2-DIR(vision, depth):基于NYU2数据集,用于构建depth estimation的DIR任务。
    4.  STS-B-DIR(NLP, test similarity score):基于STS-B数据集,任务是推断两个输入句子之间的语义文本的相似度得分。
    5.  SHHS-DIR(Healthcare, health condition score):基于SHHS数据集,该任务是推断一个人的总体健康评分。

    具体的实验可以查看该论文,这里附上论文原文以及代码地址:

    [论文]:https://arxiv.org/abs/2102.09554

    [代码]:https://github.com/YyzHarry/imbalanced-regression

    点击关注,第一时间了解华为云新鲜技术~

  • 相关阅读:
    PAT 甲级 1027 Colors in Mars
    PAT 甲级 1026 Table Tennis(模拟)
    PAT 甲级 1025 PAT Ranking
    PAT 甲级 1024 Palindromic Number
    PAT 甲级 1023 Have Fun with Numbers
    PAT 甲级 1021 Deepest Root (并查集,树的遍历)
    Java实现 蓝桥杯VIP 算法训练 无权最长链
    Java实现 蓝桥杯VIP 算法训练 无权最长链
    Java实现 蓝桥杯 算法提高 抽卡游戏
    Java实现 蓝桥杯 算法提高 抽卡游戏
  • 原文地址:https://www.cnblogs.com/huaweiyun/p/14874058.html
Copyright © 2011-2022 走看看