zoukankan      html  css  js  c++  java
  • CS231n 2016 通关 第五、六章 Batch Normalization 作业

    BN层在实际中应用广泛。

    上一次总结了使得训练变得简单的方法,比如SGD+momentum RMSProp Adam,BN是另外的方法。

    cell 1 依旧是初始化设置

    cell 2 读取cifar-10数据

    cell 3 BN的前传

     1 # Check the training-time forward pass by checking means and variances
     2 # of features both before and after batch normalization
     3 
     4 # Simulate the forward pass for a two-layer network
     5 N, D1, D2, D3 = 200, 50, 60, 3
     6 X = np.random.randn(N, D1)
     7 W1 = np.random.randn(D1, D2)
     8 W2 = np.random.randn(D2, D3)
     9 a = np.maximum(0, X.dot(W1)).dot(W2)
    10 
    11 print 'Before batch normalization:'
    12 print '  means: ', a.mean(axis=0)
    13 print '  stds: ', a.std(axis=0)
    14 
    15 # Means should be close to zero and stds close to one
    16 print 'After batch normalization (gamma=1, beta=0)'
    17 a_norm, _ = batchnorm_forward(a, np.ones(D3), np.zeros(D3), {'mode': 'train'})
    18 print '  mean: ', a_norm.mean(axis=0)
    19 print '  std: ', a_norm.std(axis=0)
    20 
    21 # Now means should be close to beta and stds close to gamma
    22 gamma = np.asarray([1.0, 2.0, 3.0])
    23 beta = np.asarray([11.0, 12.0, 13.0])
    24 a_norm, _ = batchnorm_forward(a, gamma, beta, {'mode': 'train'})
    25 print 'After batch normalization (nontrivial gamma, beta)'
    26 print '  means: ', a_norm.mean(axis=0)
    27 print '  stds: ', a_norm.std(axis=0)

      相应的核心代码:

     1     buf_mean = np.mean(x, axis=0)
     2     buf_var = np.var(x, axis=0)
     3     x_hat = x - buf_mean
     4     x_hat = x_hat / (np.sqrt(buf_var + eps))
     5 
     6     out = gamma * x_hat + beta
     7     #running_mean = momentum * running_mean + (1 - momentum) * sample_mean
     8     #running_var = momentum * running_var + (1 - momentum) * sample_var
     9     running_mean = momentum * running_mean + (1- momentum) * buf_mean
    10     running_var = momentum * running_var + (1 - momentum) * buf_var   

      running_mean  running_var 是在test时使用的,test时不再另外计算均值和方差。

      test 时的前传核心代码:

    1 x_hat = x - running_mean
    2 x_hat = x_hat / (np.sqrt(running_var + eps))
    3 out = gamma * x_hat + beta

    cell 5 BN反向传播

      通过反向传播,计算beta gamma等参数。

      核心代码:

     1   dx_hat = dout * cache['gamma'] 
     2   dgamma = np.sum(dout * cache['x_hat'], axis=0)
     3   dbeta = np.sum(dout, axis=0)
     4   #x_hat = x - buf_mean
     5   #x_hat = x_hat / (np.sqrt(buf_var + eps))
     6   t1 = cache['x'] - cache['mean']
     7   t2 = (-0.5)*((cache['var'] + cache['eps'])**(-1.5))
     8   t1 = t1 * t2
     9   d_var = np.sum(dx_hat * t1, axis=0)
    10 
    11   tmean1 = (-1)*((cache['var'] + cache['eps'])**(-0.5))
    12   d_mean = np.sum(dx_hat * tmean1, axis=0)
    13 
    14   tmean1 = (-1)*tmean1
    15   tx1 =   dx_hat * tmean1
    16   tx2 = d_mean * (1.0 / float(N))
    17   tx3 = d_var * (2 * (cache['x'] - cache['mean']) / N)
    18   dx = tx1 + tx2 + tx3

    cell 9 BN与其他层结合

      形成的结构:   {affine - [batch norm] - relu - [dropout]} x (L - 1) - affine - softmax

      原理依旧。

    之后是对cell 9 的模型,对cifar-10数据训练。

    值得注意的是:

      使用BN后,正则项与dropout层的需求降低。可以使用较高的学习率加快模型收敛。

    附:通关CS231n企鹅群:578975100 validation:DL-CS231n 

  • 相关阅读:
    多线程 信号量
    sql在不同数据库查询前几条数据
    Office Outlook同步 很奇怪的BUG
    搜索小技巧整理
    想做一个权限管理插件
    ibatis和Castle学习历程
    查找存储过程中的错误位置
    VS2005项目模版丢失解决方案及VS2005项目模版查找原理
    C# 邮件发送接收
    数据库优化整合
  • 原文地址:https://www.cnblogs.com/wangxiu/p/5689807.html
Copyright © 2011-2022 走看看