zoukankan      html  css  js  c++  java
  • 3.1 Tensorflow: 批标准化(Batch Normalization)

    BN 简介

    背景

    批标准化(Batch Normalization )简称BN算法,是为了克服神经网络层数加深导致难以训练而诞生的一个算法。根据ICS理论,当训练集的样本数据和目标样本集分布不一致的时候,训练得到的模型无法很好的泛化。

    而在神经网络中,每一层的输入在经过层内操作之后必然会导致与原来对应的输入信号分布不同,,并且前层神经网络的增加会被后面的神经网络不对的累积放大。这个问题的一个解决思路就是根据训练样本与目标样本的比例对训练样本进行一个矫正,而BN算法(批标准化)则可以用来规范化某些层或者所有层的输入,从而固定每层输入信号的均值与方差。

    使用方法

    批标准化一般用在非线性映射(激活函数)之前,对y= Wx + b进行规范化,是结果(输出信号的各个维度)的均值都为0,方差为1,让每一层的输入有一个稳定的分布会有利于网络的训练

    在神经网络收敛过慢或者梯度爆炸时的那个无法训练的情况下都可以尝试

    优点

    • 减少了参数的人为选择,可以取消dropout和L2正则项参数,或者采取更小的L2正则项约束参数
    • 减少了对学习率的要求
    • 可以不再使用局部响应归一化了,BN本身就是归一化网络(局部响应归一化-AlexNet)
    • 更破坏原来的数据分布,一定程度上缓解过拟合

    计算公式

    这里写图片描述

    其过程类似于归一化但是又不同.

    参考

    BN原理的详细参考建议:BN学习笔记:点击这里

    BN with TF

    组成部分

    BN在TensorFlow中主要有两个函数:tf.nn.moments以及tf.nn.batch_normalization,两者需要配合使用,前者用来返回均值和方差,后者用来进行批处理(BN)

    tf.nn.moments

    TensorFlow中的函数

    moments(
        x,
        axes,
        shift=None,
        name=None,
        keep_dims=False
    )
    
      Returns:
        Two `Tensor` objects: `mean` and `variance`.

    其中参数 x 为要传递的tensor,axes是个int数组,传递要进行计算的维度,返回值是两个张量: mean and variance,我们需要利用这个函数计算出BN算法需要的前两项,公式见前面的原理部分
    参考代码如下:

    # 计算Wx_plus_b 的均值与方差,其中axis = [0] 表示想要标准化的维度
    img_shape= [128, 32, 32, 64]
    Wx_plus_b = tf.Variable(tf.random_normal(img_shape))
    axis = list(range(len(img_shape)-1)) # [0,1,2] 
    wb_mean, wb_var = tf.nn.moments(Wx_plus_b, axis)

    运行结果,因为初始的数据是随机的,所以每次的运行结果并不一致:

    *** wb_mean ***
    [  1.05310767e-03   1.16801530e-03   4.95071337e-03  -1.50891789e-03
      -2.95298663e-03  -2.07848335e-03  -3.81800164e-05  -3.11688287e-03
       3.26496479e-03  -2.68524280e-04  -2.08893605e-03  -3.05374013e-03
       1.43721583e-03  -3.61034041e-03  -3.03616724e-03  -1.10225368e-03
       6.14093244e-03  -1.37914100e-03  -1.13333750e-03   3.53972078e-03
      -1.48577197e-03   1.04353309e-03   3.27868876e-03  -1.40919012e-03
       3.09609319e-03   1.98166977e-04  -5.25404140e-03  -6.03850756e-04
      -1.04614964e-03   2.90997117e-03   5.78491192e-04  -4.97420435e-04
       3.03052540e-04   2.46527663e-04  -4.70882794e-03   2.79057049e-03
      -1.98713480e-03   4.13944060e-03  -4.80978837e-04  -3.90357309e-04
       9.11145413e-04  -4.80215019e-03   6.26503082e-04  -2.76877987e-03
       3.79961479e-04   5.36157866e-04  -2.12549698e-03  -5.41620655e-03
      -1.93006988e-03  -8.54363534e-05   4.97094262e-03  -2.45843385e-03
       4.16610064e-03   2.44746287e-03  -4.15429426e-03  -6.64028199e-03
       2.56747357e-03  -1.63110415e-03  -1.53350492e-03  -7.66420271e-04
      -1.81624549e-03   2.16634944e-03   1.74984348e-03  -4.17272677e-04]
    *** wb_var ***
    [ 0.99813616  0.9983741   1.00014114  1.0012747   0.99496585  1.00168002
      1.00439012  0.99607879  1.00104094  0.99969071  1.01024568  0.99614906
      1.00092578  0.99977148  1.00447345  0.99580348  0.99797201  0.99119431
      1.00352168  0.9958936   0.99980813  1.00598109  1.00050855  0.99667317
      0.99352562  1.0036608   0.99794698  0.99324805  0.99862647  0.99930048
      0.99658304  1.00278556  0.99731135  1.00254881  0.99352133  1.00371397
      1.00258803  1.00388253  1.00404358  0.99454063  0.99434716  1.00087452
      1.00818515  1.00019705  0.99542576  1.00410056  0.99707311  1.00215423
      1.00199771  0.99394888  0.9973973   1.00197709  0.99835181  0.99944276
      0.99977624  0.99892712  0.99871159  0.99913275  1.00471914  1.00210452
      0.99568754  0.99547535  0.99983472  1.00523198]**重点内容**

    我们已经假设图片的shape[128, 32, 32, 64],它的运算方式如图:

    tf.nn.moments()的运算方式

    tf.nn.batch_normalization

    TensorFlow中的函数

    batch_normalization(
        x,
        mean,
        variance,
        offset,
        scale,
        variance_epsilon,
        name=None
    )

    其中x为输入的tensor,mean,variance由moments()求出,而offset,scale一般分别初始化为0和1,variance_epsilon一般设为比较小的数字即可,参考代码如下:

    scale = tf.Variable(tf.ones([64]))
    offset = tf.Variable(tf.zeros([64]))
    variance_epsilon = 0.001
    Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, wb_mean, wb_var, offset, scale, variance_epsilon)
    
    # 根据公式我们也可以自己写一个
    Wx_plus_b1 = (Wx_plus_b - wb_mean) / tf.sqrt(wb_var + variance_epsilon)
    Wx_plus_b1 = Wx_plus_b1 * scale + offset
    # 因为底层运算方式不同,实际上自己写的最后的结果与直接调用tf.nn.batch_normalization获取的结果并不一致

    运行结果,因为初始的数据是随机的,所以每次的运行结果并不一致,但是本例子中的计算差异始终存在:

    # 这里我们只需比较前两的矩阵即可发现存在的数值差异
    [[[[  3.32006335e-01  -1.00865233e+00   4.68401730e-01 ...,
         -1.31523395e+00  -1.13771069e+00  -2.06656289e+00]
       [  1.92613199e-01  -1.41019285e-01   1.03402412e+00 ...,
          1.66336447e-01   2.34183773e-01   1.18540943e+00]
       [ -7.14844346e-01  -1.56187916e+00  -8.09686005e-01 ...,
         -4.23679769e-01  -4.32125211e-01  -3.35091174e-01]
       ..., 
    
    [[[[  3.31096262e-01  -1.01013660e+00   4.63186830e-01 ...,
         -1.31972826e+00  -1.13898540e+00  -2.05973744e+00]
       [  1.91642866e-01  -1.42231822e-01   1.02848673e+00 ...,
          1.64460197e-01   2.32336998e-01   1.18214881e+00]
       [ -7.16206789e-01  -1.56353664e+00  -8.14172268e-01 ...,
         -4.26598638e-01  -4.33694094e-01  -3.33635926e-01]

    完整的代码

    # - * - coding: utf - 8 -*-
    #
    # 作者:田丰(FontTian)
    # 创建时间:'2017/8/2'
    # 邮箱:fonttian@Gmaill.com
    # CSDN:http://blog.csdn.net/fontthrone
    
    import tensorflow as tf
    import os
    
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # 计算Wx_plus_b 的均值与方差,其中axis = [0] 表示想要标准化的维度
    img_shape = [128, 32, 32, 64]
    Wx_plus_b = tf.Variable(tf.random_normal(img_shape))
    axis = list(range(len(img_shape) - 1))
    wb_mean, wb_var = tf.nn.moments(Wx_plus_b, axis)
    
    scale = tf.Variable(tf.ones([64]))
    offset = tf.Variable(tf.zeros([64]))
    variance_epsilon = 0.001
    Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, wb_mean, wb_var, offset, scale, variance_epsilon)
    
    Wx_plus_b1 = (Wx_plus_b - wb_mean) / tf.sqrt(wb_var + variance_epsilon)
    Wx_plus_b1 = Wx_plus_b1 * scale + offset
    
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
    
        print('*** wb_mean ***')
        print(sess.run(wb_mean))
        print('*** wb_var ***')
        print(sess.run(wb_var))
        print('*** Wx_plus_b ***')
        print(sess.run(Wx_plus_b))
        print('**** Wx_plus_b1 ****')
        print(sess.run(Wx_plus_b1))
  • 相关阅读:
    Spring MVC 拦截器
    spring中MultiActionController的数据绑定
    Hibernate多对多配置
    hibernate实体类配置文件问题(字段使用默认值)
    HibernateTemplate类的使用 (转)
    javascript小笔记(一)
    spring整合hibernate(2)
    Sina AppEngine 的bug
    找工作
    天下武功唯快不破
  • 原文地址:https://www.cnblogs.com/fonttian/p/9162800.html
Copyright © 2011-2022 走看看