zoukankan      html  css  js  c++  java
  • Batch Normalization 反向传播(backpropagation )公式的推导

    What does the gradient flowing through batch normalization looks like ?

    反向传播梯度下降权值参数更新公式的推导全依赖于复合函数求梯度时的链式法则

    1. Batch Normalization

    给定输入样本 xRN×D,经过一个神经元个数为 H 的隐层,负责连接输入层和隐层的权值矩阵 wRD×H,以及偏置向量 bRH

    Batch Normalization 的过程如下:

    • 仿射变换(affine transformation)

      h=XW+b

      显然 hRN×H

    • batch normalization 变换:

      y=γh^+β

      其中 γ,β 是待学习的参数,h^h 去均值和方差归一化的形式:

      h^=(hμ)(σ2+ϵ)1/2

      进一步其标量形式如下:

      hˆkl=(hklμl)(σ2l+ϵ)1/2

      l={1,,H}μσ 分别是对矩阵 hRN×H 的各个属性列,求均值和方差,最终构成的均值向量和方差向量。

      μl=1Nphpl,σ2l=1Np(hplμl)2

    2. Lh,Lγ,Lβ 的计算

    首先我们来看损失函数 L 关于隐层输入偏导的计算:

    dLdh=dLdh11..dLdhN1..dLdhkl...dLdh1H..dLdhNH.

    又由于:

    h=XW+b,hh^,h^y

    由链式法则可知:

    Lhij=k,lLyklyklh^klh^klhij

    显然其中 yklh^kl=γl

    又由于:

    hˆkl=(hklμl)(σ2l+ϵ)1/2,μl=1Nphpl,σ2l=1Np(hplμl)2

    所以:

    dh^kldhij=(δikδjl1Nδjl)(σ2l+ϵ)1/212(hklμl)dσ2ldhij(σ2l+ϵ)3/2

    根据 σ2lhij 的计算公式可知:

    dσ2ldhij====2Np(hplμl)(δipδjl1Nδjl)p=12N(hilμl)δjl2Nδjl1Np(hplμl)2N(hilμl)δjl2Nδjl1Nphplμl02N(hilμl)δjl

  • 相关阅读:
    animate()的使用
    jQuery UI中datepicker()的使用
    newsletter 在不同邮箱中发送的问题
    用户评价 打星等级 效果 jQuery
    SpringBoot整合WebSocket
    SpringBoot之自定义拦截器
    SpringBoot整合MyBatis
    EF实体框架简单原理及基本增删改查用法(上)
    数据库内容导出到Excel
    Asp.Net页面生命周期
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9421692.html
Copyright © 2011-2022 走看看