zoukankan      html  css  js  c++  java
  • BN 详解和使用Tensorflow实现(参数理解)

    Tensorflow   BN具体实现(多种方式):

    理论知识(参照大佬):https://blog.csdn.net/hjimce/article/details/50866313

    补充知识:

    ① tf.nn.moments  这个函数的输出就是BN需要的mean和variance。

    方式1:

    tf.nn.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None):原始接口封装使用

     x
    ·mean moments方法的输出之一
    ·variance moments方法的输出之一
    ·offset BN需要学习的参数
    ·scale BN需要学习的参数
    ·variance_epsilon 归一化时防止分母为0加的一个常量

    实现代码:

     1 import tensorflow as tf
     2 
     3 # 实现Batch Normalization
     4 def bn_layer(x,is_training,name='BatchNorm',moving_decay=0.9,eps=1e-5):
     5     # 获取输入维度并判断是否匹配卷积层(4)或者全连接层(2)
     6     shape = x.shape
     7     assert len(shape) in [2,4]
     8 
     9     param_shape = shape[-1]
    10     with tf.variable_scope(name):
    11         # 声明BN中唯一需要学习的两个参数,y=gamma*x+beta
    12         gamma = tf.get_variable('gamma',param_shape,initializer=tf.constant_initializer(1))
    13         beta  = tf.get_variable('beat', param_shape,initializer=tf.constant_initializer(0))
    14 
    15         # 计算当前整个batch的均值与方差
    16         axes = list(range(len(shape)-1))
    17         batch_mean, batch_var = tf.nn.moments(x,axes,name='moments')
    18 
    19         # 采用滑动平均更新均值与方差
    20         ema = tf.train.ExponentialMovingAverage(moving_decay)
    21 
    22         def mean_var_with_update():
    23             ema_apply_op = ema.apply([batch_mean,batch_var])
    24             with tf.control_dependencies([ema_apply_op]):
    25                 return tf.identity(batch_mean), tf.identity(batch_var)
    26 
    27         # 训练时,更新均值与方差,测试时使用之前最后一次保存的均值与方差
    28         mean, var = tf.cond(tf.equal(is_training,True),mean_var_with_update,
    29                 lambda:(ema.average(batch_mean),ema.average(batch_var)))
    30 
    31         # 最后执行batch normalization
    32         return tf.nn.batch_normalization(x,mean,var,beta,gamma,eps)

    方式2:

    tf.contrib.layers.batch_norm:封装好的批处理类

    实际上tf.contrib.layers.batch_norm对于tf.nn.moments和tf.nn.batch_normalization进行了一次封装

    参数:

    1 inputs: 输入

    2 decay :衰减系数。合适的衰减系数值接近1.0,特别是含多个9的值:0.999,0.99,0.9。如果训练集表现很好而验证/测试集表现得不好,选择

    小的系数(推荐使用0.9)。如果想要提高稳定性,zero_debias_moving_mean设为True

    3 center:如果为True,有beta偏移量;如果为False,无beta偏移量

    4 scale:如果为True,则乘以gamma。如果为False,gamma则不使用。当下一层是线性的时(例如nn.relu),由于缩放可以由下一层完成,

    所以可以禁用该层。

    5 epsilon:避免被零除

    6 activation_fn:用于激活,默认为线性激活函数

    7 param_initializers : beta, gamma, moving mean and moving variance的优化初始化

    8 param_regularizers : beta and gamma正则化优化

    9 updates_collections :Collections来收集计算的更新操作。updates_ops需要使用train_op来执行。如果为None,则会添加控件依赖项以

    确保更新已计算到位。

    10 is_training:图层是否处于训练模式。在训练模式下,它将积累转入的统计量moving_mean并 moving_variance使用给定的指数移动平均值 decay。当它不是在训练模式,那么它将使用的数值moving_mean和moving_variance。
    11 scope:可选范围variable_scope
    注意:训练时,需要更新moving_mean和moving_variance。默认情况下,更新操作被放入tf.GraphKeys.UPDATE_OPS,所以需要添加它们作为依赖项train_op。例如:

      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  with tf.control_dependencies(update_ops):    train_op = optimizer.minimize(loss)

    可以将updates_collections = None设置为强制更新,但可能会导致速度损失,尤其是在分布式设置中。

    实现代码:

    1 import tensorflow as tf
    2 
    3 def batch_norm(x,epsilon=1e-5, momentum=0.9,train=True, name="batch_norm"):
    4     with tf.variable_scope(name):
    5         epsilon = epsilon
    6         momentum = momentum
    7         name = name
    8     return tf.contrib.layers.batch_norm(x, decay=momentum, updates_collections=None, epsilon=epsilon,
    9                                         scale=True, is_training=train,scope=name)

    BN一般放哪一层?

    BN层的设定一般是按照conv->bn->scale->relu的顺序来形成一个block

    训练和测试时 BN的区别???

    bn层训练的时候,基于当前batch的mean和std调整分布;当测试的时候,也就是测试的时候,基于全部训练样本的mean和std调整分布

    所以,训练的时候需要让BN层工作,并且保存BN层学习到的参数。测试的时候加载训练得到的参数来重构测试集。

  • 相关阅读:
    SQL Server 中的事务与事务隔离级别以及如何理解脏读, 未提交读,不可重复读和幻读产生的过程和原因
    微软BI 之SSIS 系列
    微软BI 之SSIS 系列
    微软BI 之SSIS 系列
    微软BI 之SSIS 系列
    微软BI 之SSIS 系列
    微软BI 之SSAS 系列
    微软BI 之SSRS 系列
    微软BI 之SSRS 系列
    配置 SQL Server Email 发送以及 Job 的 Notification通知功能
  • 原文地址:https://www.cnblogs.com/WSX1994/p/10949079.html
Copyright © 2011-2022 走看看