zoukankan      html  css  js  c++  java
  • BN 详解和使用Tensorflow实现(参数理解)

    Tensorflow   BN具体实现(多种方式):

    理论知识(参照大佬):https://blog.csdn.net/hjimce/article/details/50866313

    补充知识:

    ① tf.nn.moments  这个函数的输出就是BN需要的mean和variance。

    方式1:

    tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None):原始接口封装使用

     x
    ·mean moments方法的输出之一
    ·variance moments方法的输出之一
    ·offset BN需要学习的参数
    ·scale BN需要学习的参数
    ·variance_epsilon 归一化时防止分母为0加的一个常量

    实现代码:

     1 import tensorflow as tf
     2 
     3 # 实现Batch Normalization
     4 def bn_layer(x,is_training,name='BatchNorm',moving_decay=0.9,eps=1e-5):
     5     # 获取输入维度并判断是否匹配卷积层(4)或者全连接层(2)
     6     shape = x.shape
     7     assert len(shape) in [2,4]
     8 
     9     param_shape = shape[-1]
    10     with tf.variable_scope(name):
    11         # 声明BN中唯一需要学习的两个参数,y=gamma*x+beta
    12         gamma = tf.get_variable('gamma',param_shape,initializer=tf.constant_initializer(1))
    13         beta  = tf.get_variable('beat', param_shape,initializer=tf.constant_initializer(0))
    14 
    15         # 计算当前整个batch的均值与方差
    16         axes = list(range(len(shape)-1))
    17         batch_mean, batch_var = tf.nn.moments(x,axes,name='moments')
    18 
    19         # 采用滑动平均更新均值与方差
    20         ema = tf.train.ExponentialMovingAverage(moving_decay)
    21 
    22         def mean_var_with_update():
    23             ema_apply_op = ema.apply([batch_mean,batch_var])
    24             with tf.control_dependencies([ema_apply_op]):
    25                 return tf.identity(batch_mean), tf.identity(batch_var)
    26 
    27         # 训练时,更新均值与方差,测试时使用之前最后一次保存的均值与方差
    28         mean, var = tf.cond(tf.equal(is_training,True),mean_var_with_update,
    29                 lambda:(ema.average(batch_mean),ema.average(batch_var)))
    30 
    31         # 最后执行batch normalization
    32         return tf.nn.batch_normalization(x,mean,var,beta,gamma,eps)

    方式2:

    tf.contrib.layers.batch_norm:封装好的批处理类

    实际上tf.contrib.layers.batch_norm对于tf.nn.moments和tf.nn.batch_normalization进行了一次封装

    参数:

    1 inputs: 输入

    2 decay :衰减系数。合适的衰减系数值接近1.0,特别是含多个9的值:0.999,0.99,0.9。如果训练集表现很好而验证/测试集表现得不好,选择

    小的系数(推荐使用0.9)。如果想要提高稳定性,zero_debias_moving_mean设为True

    3 center:如果为True,有beta偏移量;如果为False,无beta偏移量

    4 scale:如果为True,则乘以gamma。如果为False,gamma则不使用。当下一层是线性的时(例如nn.relu),由于缩放可以由下一层完成,

    所以可以禁用该层。

    5 epsilon:避免被零除

    6 activation_fn:用于激活,默认为线性激活函数

    7 param_initializers : beta, gamma, moving mean and moving variance的优化初始化

    8 param_regularizers : beta and gamma正则化优化

    9 updates_collections :Collections来收集计算的更新操作。updates_ops需要使用train_op来执行。如果为None,则会添加控件依赖项以

    确保更新已计算到位。

    10 is_training:图层是否处于训练模式。在训练模式下,它将积累转入的统计量moving_mean并 moving_variance使用给定的指数移动平均值 decay。当它不是在训练模式,那么它将使用的数值moving_mean和moving_variance。
    11 scope:可选范围variable_scope
    注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作被放入tf.GraphKeys.UPDATE_OPS,所以需要添加它们作为依赖项train_op。例如:

      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  with tf.control_dependencies(update_ops):    train_op = optimizer.minimize(loss)

    可以将updates_collections = None设置为强制更新,但可能会导致速度损失,尤其是在分布式设置中。

    实现代码:

    1 import tensorflow as tf
    2 
    3 def batch_norm(x,epsilon=1e-5, momentum=0.9,train=True, name="batch_norm"):
    4     with tf.variable_scope(name):
    5         epsilon = epsilon
    6         momentum = momentum
    7         name = name
    8     return tf.contrib.layers.batch_norm(x, decay=momentum, updates_collections=None, epsilon=epsilon,
    9                                         scale=True, is_training=train,scope=name)

    BN一般放哪一层?

    BN层的设定一般是按照conv->bn->scale->relu的顺序来形成一个block

    训练和测试时 BN的区别???

    bn层训练的时候,基于当前batch的mean和std调整分布;当测试的时候,也就是测试的时候,基于全部训练样本的mean和std调整分布

    所以,训练的时候需要让BN层工作,并且保存BN层学习到的参数。测试的时候加载训练得到的参数来重构测试集。

  • 相关阅读:
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    xgqfrms™, xgqfrms® : xgqfrms's offical website of GitHub!
    你会卖掉自己的网上信息吗?大数据可能根本不属于你
    机器学习——TensorFLow实战房价预测
    数据库运作实践三三之歌(秘制口诀)
    1000行MySQL学习笔记,收藏版!
    吐血整理深度学习入门路线及导航【教学视频+大神博客+书籍整理】+【资源页】(2019年已经最后一个月了,你还不学深度学习吗???)
    Ubuntu Snap 简述
    参数传递
  • 原文地址:https://www.cnblogs.com/WSX1994/p/10949079.html
Copyright © 2011-2022 走看看