zoukankan      html  css  js  c++  java
  • 4、交叉熵与softmax

    1、交叉熵的来源

    一条信息的信息量大小和它的不确定性有很大的关系,不确定性越大,则信息量越大。一句话如果需要很多外部信息才能确定,我们就称这句话的信息量比较大。比如你听到“云南西双版纳下雪了”,那你需要去看天气预报、问当地人等等查证(因为云南西双版纳从没下过雪)。相反,如果和你说“人一天要吃三顿饭”,那这条信息的信息量就很小,因为这条信息的确定性很高。

    将事件$x_0$的信息量定义如下(其中p($x_0$)表示事件$x_0$发生的概率):

    是表示随机变量不确定性的度量,是对所有可能发生的事件产生的信息量的期望。公式如下:

    相对熵又称KL散度,用于衡量同一个随机变量x的两个分布p(x)和q(x)之间的差异(距离)。在机器学习中,p(x)常用于描述样本的真实分布,例如[1,0,0,0]表示样本属于第一类,而q(x)则常常用于表示预测的分布,例如[0.7,0.1,0.1,0.1]。显然使用q(x)来描述样本不如p(x)准确,q(x)需要不断地学习来拟合准确的分布p(x)。
    KL散度的公式如下:

    KL散度的值越小表示两个分布越接近。

    我们将KL散度的公式进行变形,得到:
    前半部分就是p(x)的熵,后半部分就是我们的交叉熵
    机器学习中,我们常常使用KL散度来评估predict和label之间的差别,但是由于KL散度的前半部分是一个常量,所以我们常常将后半部分的交叉熵作为损失函数,其实二者是一样的。
     
    上面是从相对熵的角度推导出交叉熵的公式,同时我们可以从极大似然估计的角度推导出模型的损失函数,可以发现最小化交叉熵和最小化负对数似然函数是等价的。

    对于输入的[公式],其对应的类标签为[公式],我们的目标是找到这样的[公式]使得[公式]最大。在二分类的问题中,我们有:

    [公式]

    其中,[公式]是模型预测的概率值,[公式]是样本对应的类标签。

    将问题泛化为更一般的情况,多分类问题:

    [公式]

    由于连乘可能导致最终结果接近0的问题,一般对似然函数取对数的负数,变成最小化对数似然函数。

    [公式]

     
     
    2、分类问题中,loss函数不使用MSE而使用CE的原因 【https://github.com/HAOzj/Classic-ML-Methods-Algo/blob/master/ipynbs/appendix/loss_function/MSE%20vs%20Cross-entropy.ipynb】
    2.1、mse实际就是高斯分布的最大似然,crossEntropy是多项式分布的最大似然,分类问题当然得用多项式分布!(多分类问题的分布符合多项式分布,二分类问题的分布符合伯努利分布(二项分布)) 【参考 https://zhuanlan.zhihu.com/p/61944055 评论区】
    高斯分布的极大似然估计:https://zhuanlan.zhihu.com/p/346044291
    二项分布/多项式分布的极大似然估计:https://zhuanlan.zhihu.com/p/32341102
    2.2、MSE 多分类模型下损失函数:

    从MSE loss可以看出, MSE无差别得关注全部类别上预测概率和真实概率的差,交叉熵损失关注的是正确类别的预测概率,而我们最终目标是获得正确的类别.

    3、Softmax交叉熵损失函数

    【https://www.jianshu.com/p/1536f98c659c,https://zhuanlan.zhihu.com/p/27223959 】

    指数形式的原因:如果使用max函数,虽然能完美的进行分类但函数不可微从而无法进行训练,引入以 e 为底的指数并加权归一化,一方面指数函数使得结果将分类概率拉开了距离,另一方面函数可微。

    softmax函数求导:

    对每个样本,它属于类别[公式]的概率为:

    [公式]

    对softmax函数进行求导,即求

    [公式]

    [公式]项的输出对第[公式]项输入的偏导。
    代入softmax函数表达式,可以得到:

    [公式]

    用我们高中就知道的求导规则:对于

    [公式]

    它的导数为

    [公式]

    所以在我们这个例子中,

    [公式]

    上面两个式子只是代表直接进行替换,而非真的等式。

    [公式](即[公式])对[公式]进行求导,要分情况讨论:

    1. 如果[公式],则求导结果为[公式]
    2. 如果[公式],则求导结果为[公式]

    再来看[公式][公式]求导,结果为[公式]

    所以,当[公式]时:

    [公式]

    [公式]时:

    [公式]

    其中,为了方便,令[公式]

    softmax的计算与数值稳定性:

    在Python中,softmax函数为:

    def softmax(x):
        exp_x = np.exp(x)
        return exp_x / np.sum(exp_x)

    一种简单有效避免该问题的方法就是让exp(x)中的x值不要那么大或那么小,在softmax函数的分式上下分别乘以一个非零常数:

    [公式]

    这里[公式]是个常数,所以可以令它等于[公式]。加上常数[公式]之后,等式与原来还是相等的,所以我们可以考虑怎么选取常数[公式]。我们的想法是让所有的输入在0附近,这样[公式]的值不会太大,所以可以让[公式]的值为:

    [公式]

    这样子将所有的输入平移到0附近(当然需要假设所有输入之间的数值上较为接近),同时,除了最大值,其他输入值都被平移成负数,[公式]为底的指数函数,越小越接近0,这种方式比得到nan的结果更好。

  • 相关阅读:
    POJ 3261 Milk Patterns (求可重叠的k次最长重复子串)
    UVaLive 5031 Graph and Queries (Treap)
    Uva 11996 Jewel Magic (Splay)
    HYSBZ
    POJ 3580 SuperMemo (Splay 区间更新、翻转、循环右移,插入,删除,查询)
    HDU 1890 Robotic Sort (Splay 区间翻转)
    【转】ACM中java的使用
    HDU 4267 A Simple Problem with Integers (树状数组)
    POJ 1195 Mobile phones (二维树状数组)
    HDU 4417 Super Mario (树状数组/线段树)
  • 原文地址:https://www.cnblogs.com/ljygoodgoodstudydaydayup/p/15749386.html
Copyright © 2011-2022 走看看