zoukankan      html  css  js  c++  java
  • DNN中的梯度消失/爆炸原因

    梯度消失/梯度爆炸是深度学习中老生常谈的话题,这篇博客主要是对DNN中的梯度消失/梯度爆炸原因进行公式层面上的直观理解。

    如上图所示,假设有2个隐层,前向传播公式:

    (f_1 = sigma(w_1x+b_1),z_1 = w_1x+b_1)

    (f_2 = sigma(w_2f_1+b_2),z_2 = w_2f_1+b_2)

    (f_3 = sigma(w_3f_2+b_3),z_3 = w_3f_2+b_3)

    (f_3)是输出层的神经元,所以可以认为(loss)是关于(f_3)的函数。

    (loss)反向传播的时候,我们可以对权重(w_3, w_2, w_1)进行更新:

    (frac{partial loss}{partial w_3} = frac{partial loss}{partial f_3} frac{partial f_3}{partial w_3} = frac{partial loss}{partial f_3} sigma^{'}(w_3f_2+b_3)f_2)

    (frac{partial loss}{partial w_2} = frac{partial loss}{partial f_3} frac{partial f_3}{partial f_2} frac{partial f_2}{partial w_2} = frac{partial loss}{partial f_3} sigma^{'}(w_3f_2+b_3)w_3 sigma^{'}(w_2f_1+b_2)f_1)

    (frac{partial loss}{partial w_1} = frac{partial loss}{partial f_3} frac{partial f_3}{partial f_2} frac{partial f_2}{partial f_1} frac{partial f_1}{partial w_1} = frac{partial loss}{partial f_3} sigma^{'}(w_3f_2+b_3)w_3 sigma^{'}(w_2f_1+b_2)w_2 sigma^{'}(w_1x+b_1)x)

    根据上面规律,我们可以把(x)写成(f_0),当有n-1层隐层时,(f_n)是输出,如果要求(w_l)也就是第(l)层的权重,反向传播中涉及的偏导计算为:

    (frac{partial loss}{partial w_l } = frac{partial loss }{partial {f_n} } prod_{i=l}^{n}sigma^{'}(w_if_{i-1} + b_i)prod_{i=l+1}^{n}w_i f_{l-1})

    上面这个式子就是我们要推导的核心!

    当梯度反向传播到第(l)层的时候,我们用上述公式计算偏导,根据链式法则,上面用大括号括起来的就是累乘项,其中前半部分是关于激活函数的导数的累乘,后半部分是关于权重值的累乘。我们知道,激活函数比如sigmoid函数,其导数的取值范围是((0, frac{1}{4}]),是恒小于1的,当网络层数很深的时候,多个小于1的数进行累乘,结果是趋向于0的,也就是说此时,梯度反向传播的时候,根据参数更新公式( heta := heta - alpha cdot frac{partial loss}{ heta}),偏导部分的取值趋于0,那么该参数得不到更新,也就出现了我们说的梯度消失现象。

    另外,我们也注意到,大括号的后半部分是关于权重值的累乘,当我们初始化权值很大的时候,多个大于1的数累乘,结果是(+infty),此时就出现了梯度爆炸现象。

  • 相关阅读:
    Java Object-Oriented:day11 【 Inner Classes 】
    Java Object-Oriented:day11 【 final 】
    Java Object-Oriented:day10 【 Inheritance】
    Java Object-Oriented:day10 【 Interfaces】
    深入剖析Kubernetes学习笔记:StatefulSet-MySQL集群(20)
    Java Object-Oriented:day09 【 Inheritance、Abstract Classes】
    Java Object-Oriented:day09 【 Inheritance、super、this】
    kubernetes云平台管理实战:常用命令大全(十二)
    Java API:day08 【static、Arrays、Math】
    Java API:day08 【String】
  • 原文地址:https://www.cnblogs.com/Elaine-DWL/p/11239969.html
Copyright © 2011-2022 走看看