zoukankan      html  css  js  c++  java
  • 十分钟看懂神经网络反向传输算法

    昨天面试被问到如何推导BP(反向传输)算法,顿时蒙住了,大体是知道反向传输算法的过程的,但是用语言描述出来,确实有些困难。回来后看了些博文,发现有的博文中公式推导过于复杂,不易理解,遂综合了网络中其他博文和斯坦福大学CS231n课程中的内容,整理了一份反向传输算法的通俗解释,如有错误,请各位网友指出。

    一、反向传输(BP)算法的作用是什么?

    首先我们要知道我们的优化目标是什么,对于神经网络模型的优化实质上就是对整体损失函数 L(成本函数) 的优化

    其中 L为样本集中第 i 个样本的损失值,(xi,yi)为第i个样本。

    损失函数 L 的自变量是网络中所有的参数,训练的目的是找到一组参数,使得损失函数 L 达到最小值(或者局部最小值)。通常使用梯度下降算法进行优化参数,对于具体的优化算法,这里不再叙述,具体可以参看深度学习——优化器算法Optimizer详解BGDSGDMBGDMomentumNAGAdagradAdadeltaRMSpropAdam一文。这些算法都是阐述了如何更好的使用梯度信息来快速优化成本函数、找到最有解,但这些算法的前提都是获得了成本函数的梯度值,对于深度网络可能有上亿的参数需要优化,如何高效的求解出 L 对这上亿参数的偏导数,便成为一个难题,反向传输(BP)算法即用来高效的计算这些参数的偏导数,进而得出成本函数(损失函数L)的梯度。

    补充:

    对于欲优化的代价函数,为权重的函数C=C(W),是一个非常复杂的复合函数,直接使用链式法则对各参数求偏导非常复杂。如果使用导数定义进行树数值求导,一个很明显的计算方式是使用近似:

    其中是一个大于零的小正数。换句话说,我们可以通过计算两个差距很小的wj的代价,然后利用上式来估计。这个方法看起来很不错,使用的是导数的定义求法,但是计算过于复杂,设想我们的神经网络中有100万个待优化的参数,计算一次梯度就要重复计算100万次。进而对于每一个样本,都要在神经网络中向前传播100万次。而且还需要计算C(W),相当于要一百万零一次前向传播。

    反向传播的优点在于他尽力用一次前向传播加一次反向传播就可以同时计算出所有的偏导数,大致来讲,反向传播所需要的总计算量与前向传播的计算量基本相等,原因在于前向传播时主要的计算量在于权重矩阵的乘法计算,反向传播时主要的计算量在于权重矩阵转置的乘法,很明显,它们的计算量差不多。利用反向传播算法求成本函数的梯度,大大减小了计算量,使神经网络优化更快!

    举个例子:

    二、反向传输算法过程

    1.链式求导

    首先我们回顾一下微积分中对于复合函数的求导过程,对于任意复合函数,例如:

    这就是我们常说的链式求导法则。反向传输算法正是利用了这种链式求导法则。

    2.用计算图来解释几种求导方法

    2.1 计算图

    式子 e=(a+b)*(b+1) 可以用如下计算图表达:

    令a=2,b=1则有:

    如何在计算图上表达“求导”呢? 导数的含义是因变量随自变量的变化率,例如 frac{partial y }{partial x} = 3  表示当x变化1个单位,y会变化3个单位。 微积分中已经学过:加法求导法则是 frac{partial}{partial a}(a+b) = frac{partial a}{partial a} + frac{partial b}{partial a} = 1 乘法求导法则是 frac{partial}{partial u}uv = ufrac{partial v}{partial u} + vfrac{partial u}{partial u} = v 。 我们在计算图的边上表示导数或偏导数:frac{ partial e }{ partial c } , frac{ partial e }{ partial d }, frac{ partial c }{ partial a }, frac{ partial c }{ partial b }, frac{ partial d }{ partial b } 如下图

    那么 frac{ partial e  }{ partial b } 如何求呢? frac{partial c }{ partial b} = 1 告诉我们1个单位的b变化会引起1个单位的c变换,frac{partial e }{ partial c} = 2告诉我们 1 个单位的c变化会引起2个单位的e变化。所以 frac{ partial e  }{ partial b } =   frac{ partial c }{ partial b } * frac{ partial e  }{ partial c }   = 1*2 =2 吗? 答案必然是错误。因为这样做只考虑到了下图橙色的路径,所有的路径都要考虑:frac{ partial e  }{ partial b } =   frac{ partial c }{ partial b } * frac{ partial e  }{ partial c }  +  frac{ partial d  }{ partial b }  *  frac{ partial e  }{ partial d }  =1*2 + 1 * 3 = 5

    所以上面的求导方法总结为一句话就是: 路径上所有边相乘,所有路径相加。不过这里需要补充一条很有用的合并策略:

    例如:下面的计算图若要计算frac{partial Z}{partial X}就会有9条路径:frac{partial Z}{partial X} = alphadelta + alphaepsilon + alphazeta + etadelta + etaepsilon + etazeta + gammadelta + gammaepsilon + gammazeta

    如果计算图再复杂一些,层数再多一些,路径数量就会呈指数爆炸性增长。但是如果采用合并策略:frac{partial Z}{partial X} = (alpha + eta + gamma)(delta + epsilon + zeta) 就不会出现这种问题。这种策略不是 对每一条路径都求和,而是 “合并同类路径”,“分阶段求解”。先求X对Y的总影响 (alpha + eta + gamma) 再求Y对Z的总影响 (delta + epsilon + zeta) 最后综合在一起。

    2.2 两种求导模式:前向模式求导( forward-mode differentiation) 反向模式求导(reverse-mode differentiation)

    上面提到的求导方法都是前向模式求导( forward-mode differentiation) :从前向后。先求X对Y的总影响 (alpha + eta + gamma) 再乘以Y对Z的总影响 (delta + epsilon + zeta) 。

    另一种,反向模式求导(reverse-mode differentiation) 则是从后向前。先求Y对Z的影响再乘以X对Y的影响。

    前向求导模式追踪一个输入如何影响每一个节点(对每一个节点进行 frac{partial}{partial X}操作)反向求导模式追踪每一个节点如何影响一个输出(对每一个节点进行 frac{partial Z}{partial}操作)。通俗点理解,前向求导模式和反向求导模式只是求导顺序的不同,但是顺序不同运算复杂度也不相同。

    2.3 反向求导模式(反向传播算法)的重要性

    让我们再次考虑前面的例子:

    如果用前向求导模式:关于b向前求导一次

    如果用反向求导模式:向后求导

    前向求导模式只得到了关于输入b的偏导 frac{partial e}{partial b} ,还需要再次求解关于输入a的偏导frac{partial e}{partial a} (运算2遍)。而反向求导一次运算就得到了e对两个输入a,b的偏导frac{partial e}{partial a}, frac{partial e}{partial b} (运算1遍)。上面的比较只看到了2倍的加速。但如果有1亿个输入1个输出,意味着前向求导需要操作1亿遍才得到所有关于输入的偏导,而反向求导则只需一次运算,1亿倍的加速。

    当我们训练神经网络时,把“损失“ 看作 ”权重参数“ 的函数,需要计算”损失“关于每一个”权重参数“的偏导数(然后用梯度下降法学习)。 神经网络的权重参数可以是百万甚至过亿级别。因此反向求导模式(反向传播算法)可以极大的加速学习。

    用更通俗易懂的话来描述反向传输算法,从目标函数开始,逐层向前求解每一层每个结点(运算)的局部梯度,根据链式法则可知,整个网络中成本函数对于某一个参数的偏导数,从成本函数流经本参数所有指路上的偏导数乘积叠加,这样一次运算就可以获得所有参数的偏导数,即成本函数的梯度。梯度从后向前逐层传递。

    通过计算流图对该算法进行简要解释:


     附录:机器学习常用求导公式

     X是向量,W是与X无关的矩阵

    后两个很重要!!!

     

  • 相关阅读:
    Illegal mix of collations (latin1_swedish_ci,COERCIBLE) and (gbk_chinese_ci,COERCIBLE) for operation '=' 一个解决办法(转载)
    mysql limit用法
    preparedStatement一个小技巧
    两个简单的压力测试代码。
    cookie实现session机制
    java.util.properties用法
    数据库是否使用外键,及视图,索引,存储过程的一些说明(zz)
    某项目要调用现有的100多个DLL 二 最最简单原型的思考
    面试题:红绿灯
    一个简单的封装 .net的日志功能
  • 原文地址:https://www.cnblogs.com/guoyaohua/p/8617086.html
Copyright © 2011-2022 走看看