zoukankan      html  css  js  c++  java
  • 交叉熵损失函数的求导(Logistic回归)

    前言

    最近有遇到些同学找我讨论sigmoid训练多标签或者用在目标检测中的问题,我想写一些他们的东西,想到以前的博客里躺着这篇文章(2015年读研时机器学课的作业)感觉虽然不够严谨,但是很多地方还算直观,就先把它放过来吧。

    说明: 本文只讨论Logistic回归的交叉熵,对Softmax回归的交叉熵类似(Logistic回归和Softmax回归两者本质是一样的,后面我会专门有一篇文章说明两者关系,先在这里挖个坑)。 首先,我们二话不说,先放出逻辑回归交叉熵的公式:

    [公式]

    以及 [公式] 对参数 [公式] 的偏导数(用于诸如梯度下降法等优化算法的参数更新),如下:

    [公式]

    但是在大多论文或数教程中,也就是直接给出了上面两个公式,而未给出推导过程,这就给初学者造成了一定的困惑。交叉熵的公式可以用多种解释得到,甚至不同领域也会有不同,比如数学系的用极大似然估计,信息工程系的的从信息编码角度,当然更多是联合KL散度来解释。但是我这里假设那些你都不了解的情况下如何用一个更加直白和直观的解释来得到Logistic Regression的交叉熵损失函数,说清楚它存在的合理性就可以解惑(关于交叉熵的所谓"正统"解释后续我会专门写一篇文章来总结,先挖个坑)。因水平有限,如有错误,欢迎指正。

    废话不说,下文将介绍一步步得到Logistic Regression的交叉熵损失函数,并推导出其导数,同时给出简洁的向量形式及其导数推导过程。

    交叉熵损失函数(Logistic Regression代价函数)

    我们一共有 [公式] 组已知样本( [公式] ), [公式] 表示第 [公式] 组数据及其对应的类别标记。其中 [公式][公式] 维向量(考虑偏置项), [公式] 则为表示类别的一个数:

    • logistic回归(是非问题)中, [公式] 取0或者1;
    • softmax回归 (多分类问题)中, [公式] 取1,2...k中的一个表示类别标号的一个数(假设共有k类)。

    这里,只讨论logistic回归,输入样本数据 [公式] ,模型的参数为 [公式] ,因此有

    [公式]

    二元问题中常用sigmoid作为假设函数(hypothesis function),定义为:

    [公式]

    因为Logistic回归问题就是0/1的二分类问题,可以有

    [公式]

    现在,我们不考虑“熵”的概念,根据下面的说明,从简单直观角度理解,就可以得到我们想要的损失函数:我们将概率取对数,其单调性不变,有

    [公式]

    那么对于第 [公式] 组样本,假设函数表征正确的组合对数概率为:

    [公式]

    其中, [公式][公式] 为示性函数(indicative function),简单理解为{ }内条件成立时,取1,否则取0,这里不赘言。 那么对于一共 [公式] 组样本,我们就可以得到模型对于整体训练样本的表现能力:

    [公式]

    由以上表征正确的概率含义可知,我们希望其值越大,模型对数据的表达能力越好。而我们在参数更新或衡量模型优劣时是需要一个能充分反映模型表现误差的损失函数(Loss function)或者代价函数(Cost function)的,而且我们希望损失函数越小越好。由这两个矛盾,那么我们不妨领代价函数为上述组合对数概率的相反数:

    [公式]

    上式即为大名鼎鼎的交叉熵损失函数。(说明:如果熟悉“信息熵"的概念 [公式] ,那么可以有助理解叉熵损失函数)

    交叉熵损失函数的求导

    这步需要用到一些简单的对数运算公式,这里先以编号形式给出,下面推导过程中使用特意说明时都会在该步骤下脚标标出相应的公式编号,以保证推导的连贯性。

    [公式]

    [公式]

    [公式] (为了方便这里 [公式][公式] ,即 [公式] ,其他底数如2,10等,由换底公式可知,只是前置常数系数不同,对结论毫无影响)

    另外,值得一提的是在这里涉及的求导均为矩阵、向量的导数(矩阵微商),这里有一篇教程总结得精简又全面,非常棒,推荐给需要的同学。

    下面开始推导:

    交叉熵损失函数为:

    [公式]

    其中,

    [公式]

    由此,得到

    [公式]

    这次再计算 [公式] 对第 [公式] 个参数分量 [公式] 求偏导:

    [公式]

    这就是交叉熵对参数的导数:

    [公式]

     

    向量形式

    前面都是元素表示的形式,只是写法不同,过程基本都是一样的,不过写成向量形式会更清晰,这样就会把 [公式] 和求和符号 [公式] 省略掉了。我们不妨忽略前面的固定系数项 [公式] ,交叉墒的损失函数(1)则可以写成下式:

    [公式]

    [公式] 带入,得到:

    [公式]

    再对 [公式] 求导,前面的负号直接削掉了,

    [公式]

     

    3 梯度下降参数更新

     

     

    转载请注明出处Jason Zhao的知乎专栏“人工+智能“,文章链接:

    Jason Zhao:交叉熵损失函数的求导(Logistic回归)

  • 相关阅读:
    Petya and Countryside
    大数A+B
    python-requests正则
    python-UnicodeDecodeError: 'gbk' codec can't decode byte 0xa8 in position 157: illegal multibyte sequence
    python-mysql数据迁移
    python-flask框架路由
    python-flask框架基础
    MYSQL-内外自连接-判断函数
    MYSQL-分组查询-where和having的区别
    mysql增删
  • 原文地址:https://www.cnblogs.com/celine227/p/15102377.html
Copyright © 2011-2022 走看看