zoukankan      html  css  js  c++  java
  • 机器学习常用损失函数

    机器学习常用损失函数

    转载自:机器学习常用损失函数小结 - 王桂波的文章 - 知乎 https://zhuanlan.zhihu.com/p/776861188

    1.Loss Function、Cost Function 和 Objective Function 的区别和联系

    • 损失函数 Loss Function 通常是针对单个训练样本而言,给定一个模型输出 [公式] 和一个真实 [公式],损失函数输出一个实值损失 [公式]

    • 代价函数 Cost Function 通常是针对整个训练集(或者在使用 mini-batch gradient descent 时一个 mini-batch)的总损失 [公式]

    • 目标函数 Objective Function 是一个更通用的术语,表示任意希望被优化的函数,用于机器学习领域和非机器学习领域(比如运筹优化)

      由于损失函数和代价函数只是在针对样本集上有区别,因此在本文中统一使用了损失函数这个术语,但下文的相关公式实际上采用的是代价函数 Cost Function 的形式。

    2.回归常用的损失函数

    01.均方差损失 Mean Squared Error Loss(MSE)

    • 均方差 Mean Squared Error (MSE) 损失是机器学习、深度学习回归任务中最常用的一种损失函数,也称为 L2 Loss。其基本形式如下

    [公式]

    从直觉上理解均方差损失,这个损失函数的最小值为 0(当预测等于真实值时),最大值为无穷大。下图是对于真实值 [公式] ,不同的预测值 [公式] 的均方差损失的变化图。横轴是不同的预测值,纵轴是均方差损失,可以看到随着预测与真实值绝对误差 [公式] 的增加,均方差损失呈二次方地增加。

    img

    • 背后的假设

    实际上在一定的假设下,我们可以使用最大化似然得到均方差损失的形式。假设模型预测与真实值之间的误差服从标准高斯分布[公式] ),则给定一个 [公式] 模型输出真实值 [公式] 的概率为

    [公式]

    进一步我们假设数据集中 N 个样本点之间相互独立,则给定所有 [公式] 输出所有真实值 [公式] 的概率,即似然 Likelihood,为所有 [公式] 的累乘

    [公式]

    通常为了计算方便,我们通常最大化对数似然 Log-Likelihood

    [公式]

    去掉与 [公式] 无关的第一项,然后转化为最小化负对数似然 Negative Log-Likelihood

    [公式]

    可以看到这个实际上就是均方差损失的形式。也就是说在模型输出与真实值的误差服从高斯分布的假设下,最小化均方差损失函数与极大似然估计本质上是一致的,因此在这个假设能被满足的场景中(比如回归),均方差损失是一个很好的损失函数选择;当这个假设没能被满足的场景中(比如分类),均方差损失不是一个好的选择。

    02.平均绝对误差损失 Mean Absolute Error Loss(MAE)

    [公式]

    同样的我们可以对这个损失函数进行可视化如下图,MAE 损失的最小值为 0(当预测等于真实值时),最大值为无穷大。可以看到随着预测与真实值绝对误差 [公式] 的增加,MAE 损失呈线性增长

    img

    • 背后的假设

      同样的我们可以在一定的假设下通过最大化似然得到 MAE 损失的形式,假设模型预测与真实值之间的误差服从拉普拉斯分布 Laplace distribution[公式] ),则给定一个 [公式] 模型输出真实值 [公式] 的概率为

      [公式]

      与上面推导 MSE 时类似,我们可以得到的负对数似然实际上就是 MAE 损失的形式

      [公式]

      对比MAE与MSE的区别

      MAE 和 MSE 作为损失函数的主要区别是:MSE 损失相比 MAE 通常可以更快地收敛,但 MAE 损失对于 outlier 更加健壮,即更加不易受到 outlier 影响。

      MSE 通常比 MAE 可以更快地收敛。当使用梯度下降算法时,MSE 损失的梯度为 [公式] ,而 MAE 损失的梯度为 [公式] ,即 MSE 的梯度的 scale 会随误差大小变化,而 MAE 的梯度的 scale 则一直保持为 1,即便在绝对误差 [公式] 很小的时候 MAE 的梯度 scale 也同样为 1,这实际上是非常不利于模型的训练的。当然你可以通过在训练过程中动态调整学习率缓解这个问题,但是总的来说,损失函数梯度之间的差异导致了 MSE 在大部分时候比 MAE 收敛地更快。这个也是 MSE 更为流行的原因。

      MAE 对于 outlier 更加 robust。我们可以从两个角度来理解这一点:

    • 第一个角度是直观地理解,下图是 MAE 和 MSE 损失画到同一张图里面,由于MAE 损失与绝对误差之间是线性关系,MSE 损失与误差是平方关系,当误差非常大的时候,MSE 损失会远远大于 MAE 损失。因此当数据中出现一个误差非常大的 outlier 时,MSE 会产生一个非常大的损失,对模型的训练会产生较大的影响。

    img

    • 第二个角度是从两个损失函数的假设出发,MSE 假设了误差服从高斯分布,MAE 假设了误差服从拉普拉斯分布。拉普拉斯分布本身对于 outlier 更加 robust。参考下图(来源:Machine Learning: A Probabilistic Perspective 2.4.3 The Laplace distribution Figure 2.8),当右图右侧出现了 outliers 时,拉普拉斯分布相比高斯分布受到的影响要小很多。因此以拉普拉斯分布为假设的 MAE 对 outlier 比高斯分布为假设的 MSE 更加 robust。

    img

    普及:拉普拉斯分布与高斯分布

    • 一元拉普拉斯(laplace)也叫双指数分布,可以和正态分布进行对比,其密度函数为:image-20201027123538987
      其中σ为尺度参数;μ 为位置参数

    在这里插入图片描述

    拉普拉斯分布关于μ 对称,并达到最大值/12σ

    在这里插入图片描述

    • 高斯分布(一般指一元高斯分布)又称为正态分布,是常见的连续概率分布。
      在这里插入图片描述
      高斯分布的重要性质:
      • 密度函数关于平均值对称
      • 平均值与他的众数、中位数为同一值
      • 函数曲线下68.268949%的面积在平均数左右的一个标准差范围内
      • 95.449974%的面积在平均数左右两个标准差2 σ 2 sigma2σ的范围内
      • 99.730020%的面积在平均数左右三个标准差3 σ 3 sigma3σ的范围内
      • 函数曲线的拐点(inflection point)为离平均数一个标准差距离的位置。

    03 Huber Loss

    上文我们分别介绍了 MSE 和 MAE 损失以及各自的优缺点,MSE 损失收敛快但容易受 outlier 影响,MAE 对 outlier 更加健壮但是收敛慢,Huber Loss 则是一种将 MSE 与 MAE 结合起来,取两者优点的损失函数,也被称作 Smooth Mean Absolute Error Loss 。其原理很简单,就是在误差接近 0 时使用 MSE,误差较大时使用 MAE,公式为

    [公式]

    上式中 [公式] 是 Huber Loss 的一个超参数,[公式] 的值是 MSE 和 MAE 两个损失连接的位置。上式等号右边第一项是 MSE 的部分,第二项是 MAE 部分,在 MAE 的部分公式为 [公式]是为了保证误差 [公式] 时 MAE 和 MSE 的取值一致,进而保证 Huber Loss 损失连续可导。

    下图是 [公式] 时的 Huber Loss,可以看到在 [公式] 的区间内实际上就是 MSE 损失,在[公式][公式] 区间内为 MAE损失。

    img

    • Huber Loss 的特点

    Huber Loss 结合了 MSE 和 MAE 损失,在误差接近 0 时使用 MSE,使损失函数可导并且梯度更加稳定;在误差较大时使用 MAE 可以降低 outlier 的影响,使训练对 outlier 更加健壮。缺点是需要额外地设置一个 [公式] 超参数。

    04 分位数损失 Quantile Loss

    分位数回归 Quantile Regression 是一类在实际应用中非常有用的回归算法,通常的回归算法是拟合目标值的期望或者中位数,而分位数回归可以通过给定不同的分位点,拟合目标值的不同分位数。例如我们可以分别拟合出多个分位点,得到一个置信区间,如下图所示(图片来自笔者的一个分位数回归代码 demo Quantile Regression Demo

    img

    分位数回归是通过使用分位数损失 Quantile Loss 来实现这一点的,分位数损失形式如下,式中的 r 分位数系数。

    [公式]

    我们如何理解这个损失函数呢?这个损失函数是一个分段的函数 ,将 [公式] (高估) 和[公式] (低估) 两种情况分开来,并分别给予不同的系数。当 [公式] 时,低估的损失要比高估的损失更大,反过来当 [公式] 时,高估的损失比低估的损失大;分位数损失实现了分别用不同的系数控制高估和低估的损失,进而实现分位数回归。特别地,当 [公式] 时,分位数损失退化为 MAE 损失,从这里可以看出 MAE 损失实际上是分位数损失的一个特例 — 中位数回归(这也可以解释为什么 MAE 损失对 outlier 更鲁棒:MSE 回归期望值,MAE 回归中位数,通常 outlier 对中位数的影响比对期望值的影响小)。

    [公式]

    下图是取不同的分位点 0.2、0.5、0.6 得到的三个不同的分位损失函数的可视化,可以看到 0.2 和 0.6 在高估和低估两种情况下损失是不同的,而 0.5 实际上就是 MAE。

    img

    3 分类常用的损失函数

    01 交叉熵损失

    二分类

    考虑二分类,在二分类中我们通常使用 Sigmoid 函数将模型的输出压缩到 (0, 1) 区间内[公式] ,用来代表给定输入 [公式] ,模型判断为正类的概率。由于只有正负两类,因此同时也得到了负类的概率。

    [公式]

    将两条式子合并成一条

    [公式]

    假设数据点之间独立同分布,则似然可以表示为

    [公式]

    对似然取对数,然后加负号变成最小化负对数似然,即为交叉熵损失函数的形式

    [公式]

    下图是对二分类的交叉熵损失函数的可视化,蓝线是目标值为 0 时输出不同输出的损失,黄线是目标值为 1 时的损失。可以看到约接近目标值损失越小,随着误差变差,损失呈指数增长。

    img

    多分类

    在多分类的任务中,交叉熵损失函数的推导思路和二分类是一样的,变化的地方是真实值 [公式] 现在是一个 One-hot 向量,同时模型输出的压缩由原来的 Sigmoid 函数换成 Softmax 函数。Softmax 函数将每个维度的输出范围都限定在 [公式] 之间,同时所有维度的输出和为 1,用于表示一个概率分布。

    [公式]

    其中 [公式] 表示 K 个类别中的一类,同样的假设数据点之间独立同分布,可得到负对数似然为

    [公式]

    由于 [公式] 是一个 one-hot 向量,除了目标类为 1 之外其他类别上的输出都为 0,因此上式也可以写为

    [公式]

    其中 [公式] 是样本 [公式] 的目标类。通常这个应用于多分类的交叉熵损失函数也被称为 Softmax Loss 或者 Categorical Cross Entropy Loss。

    Cross Entropy is good. But WHY?

    分类中为什么不用均方差损失?上文在介绍均方差损失的时候讲到实际上均方差损失假设了误差服从高斯分布,在分类任务下这个假设没办法被满足,因此效果会很差。为什么是交叉熵损失呢?有两个角度可以解释这个事情,一个角度从最大似然的角度,也就是我们上面的推导;另一个角度是可以用信息论来解释交叉熵损失:

    假设对于样本 [公式] 存在一个最优分布 [公式] 真实地表明了这个样本属于各个类别的概率,那么我们希望模型的输出 [公式] 尽可能地逼近这个最优分布,在信息论中,我们可以使用 KL 散度 Kullback–Leibler Divergence 来衡量两个分布的相似性。给定分布 [公式] 和分布 [公式] , 两者的 KL 散度公式如下

    [公式]

    其中第一项为分布 [公式] 的信息熵,第二项为分布 [公式][公式] 的交叉熵。将最优分布 [公式] 和输出分布[公式] 带入 [公式][公式] 得到

    [公式]

    由于我们希望两个分布尽量相近,因此我们最小化 KL 散度。同时由于上式第一项信息熵仅与最优分布本身相关,因此我们在最小化的过程中可以忽略掉,变成最小化

    [公式]

    我们并不知道最优分布 [公式] ,但训练数据里面的目标值 [公式] 可以看做是 [公式] 的一个近似分布

    [公式]

    这个是针对单个训练样本的损失函数,如果考虑整个数据集,则

    [公式]

    可以看到通过最小化交叉熵的角度推导出来的结果和使用最大 化似然得到的结果是一致的

    02 合页损失Hinge Loss

    合页损失 Hinge Loss 是另外一种二分类损失函数,适用于 maximum-margin 的分类,支持向量机 Support Vector Machine (SVM) 模型的损失函数本质上就是 Hinge Loss + L2 正则化。合页损失的公式如下

    [公式]

    下图是 [公式] 为正类, 即 [公式] 时,不同输出的合页损失示意图

    img

    可以看到当 [公式] 为正类时,模型输出负值会有较大的惩罚,当模型输出为正值且在 [公式] 区间时还会有一个较小的惩罚。即合页损失不仅惩罚预测错的,并且对于预测对了但是置信度不高的也会给一个惩罚,只有置信度高的才会有零损失。使用合页损失直觉上理解是要找到一个决策边界,使得所有数据点被这个边界正确地、高置信地被分类

    总结

    本文针对机器学习中最常用的几种损失函数进行相关介绍,首先是适用于回归的均方差损失 Mean Squared Loss、平均绝对误差损失 Mean Absolute Error Loss,两者的区别以及两者相结合得到的 Huber Loss,接着是应用于分位数回归的分位数损失 Quantile Loss,表明了平均绝对误差损失实际上是分位数损失的一种特例,在分类场景下,本文讨论了最常用的交叉熵损失函数 Cross Entropy Loss,包括二分类和多分类下的形式,并从信息论的角度解释了交叉熵损失函数,最后简单介绍了应用于 SVM 中的 Hinge 损失 Hinge Loss。本文相关的可视化代码在 这里

    受限于时间,本文还有其他许多损失函数没有提及,比如应用于 Adaboost 模型中的指数损失 Exponential Loss,0-1 损失函数等。另外通常在损失函数中还会有正则项(L1/L2 正则),这些正则项作为损失函数的一部分,通过约束参数的绝对值大小以及增加参数稀疏性来降低模型的复杂度,防止模型过拟合,这部分内容在本文中也没有详细展开。读者有兴趣可以查阅相关的资料进一步了解。That’s all. Thanks for reading.

    自律, 坚定, 随和, 坚强, 为了自己想要的,去奋斗
  • 相关阅读:
    Nginx,uWSGI与Django 应用的关系
    闭包学习-Python 篇
    Django学习之REST framework JWT Auth
    Python标准库uuid模块
    Django REST framework学习之JWT失效方式
    Django学习之JWT
    单点登录
    print输出格式总结
    百钱百鸡问题
    流程图符号及其功能
  • 原文地址:https://www.cnblogs.com/xiaowututu/p/13905267.html
Copyright © 2011-2022 走看看