  • 『TensorFlow』Batch normalization classes

    『Tutorial』Introduction to the Batch Normalization layer

     Basics

    Below is Morvan's (莫凡) explanation of batch normalization:

    import tensorflow as tf  # TensorFlow 1.x

    # Wx_plus_b and out_size come from the enclosing fully connected layer
    fc_mean, fc_var = tf.nn.moments(
        Wx_plus_b,
        axes=[0],
        # the dimensions to normalize over; [0] is the batch dimension
        # for image data you can pass [0, 1, 2], i.e. mean/variance over [batch, height, width]; do not include the channel dimension
    )
    scale = tf.Variable(tf.ones([out_size]))
    shift = tf.Variable(tf.zeros([out_size]))
    epsilon = 0.001
    Wx_plus_b = tf.nn.batch_normalization(Wx_plus_b, fc_mean, fc_var, shift, scale, epsilon)
    # the step above does the following:
    # Wx_plus_b = (Wx_plus_b - fc_mean) / tf.sqrt(fc_var + 0.001)
    # Wx_plus_b = Wx_plus_b * scale + shift
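    For readers without TensorFlow at hand, the two steps above can be traced in plain Python. This is a minimal sketch, assuming a 2-D batch [batch, features] given as nested lists; `moments` and `batch_normalization` here are hypothetical stand-ins that only mirror what `tf.nn.moments(..., axes=[0])` (biased variance) and `tf.nn.batch_normalization` (epsilon added inside the square root) compute.

```python
import math


def moments(batch):
    """Per-feature mean and biased variance over the batch axis (like axes=[0])."""
    n = len(batch)
    num_features = len(batch[0])
    mean = [sum(row[j] for row in batch) / n for j in range(num_features)]
    var = [sum((row[j] - mean[j]) ** 2 for row in batch) / n for j in range(num_features)]
    return mean, var


def batch_normalization(batch, mean, var, offset, scale, epsilon):
    """(x - mean) / sqrt(var + epsilon) * scale + offset, feature by feature."""
    return [
        [
            (row[j] - mean[j]) / math.sqrt(var[j] + epsilon) * scale[j] + offset[j]
            for j in range(len(row))
        ]
        for row in batch
    ]


# Toy "Wx_plus_b": 4 samples, 2 features.
wx_plus_b = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
fc_mean, fc_var = moments(wx_plus_b)
out_size = 2
scale = [1.0] * out_size   # plays the role of tf.ones([out_size])
shift = [0.0] * out_size   # plays the role of tf.zeros([out_size])
normed = batch_normalization(wx_plus_b, fc_mean, fc_var, shift, scale, 0.001)
print(fc_mean, fc_var)  # [2.5, 25.0] [1.25, 125.0]
```

    Each normalized column ends up with mean 0 and variance var / (var + epsilon), just under 1, before scale and shift are applied.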
    

    tf.contrib.layers.batch_norm: the ready-made batch normalization wrapper

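    import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in TensorFlow 2.0

    class batch_norm():
        '''Batch normalization layer'''
    
        def __init__(self, epsilon=1e-5,
                     momentum=0.9, name='batch_norm'):
            '''
            Initialize the layer
            :param epsilon:    small constant that keeps the denominator away from zero
            :param momentum:   decay of the moving averages
            :param name:       variable scope name
            '''
            self.epsilon = epsilon
            self.momentum = momentum
            self.name = name
    
        def __call__(self, x, train=True):
            # High-level wrapper that performs batch normalization internally
            return tf.contrib.layers.batch_norm(x,
                                                decay=self.momentum,        # moving-average decay
                                                updates_collections=None,   # update moving_mean/moving_variance in place (see note 1)
                                                epsilon=self.epsilon,
                                                scale=True,                 # also learn gamma; beta is learned by default (center=True)
                                                is_training=train,          # True: batch statistics + moving-average updates; False: moving averages
                                                scope=self.name)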
    

    1.

    Note: when training, the moving_mean and moving_variance need to be updated.
        By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
        need to be added as a dependency to the `train_op`. For example:
        
        ```python
          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
          with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss)
        ```
        
        One can set updates_collections=None to force the updates in place, but that
        can have a speed penalty, especially in distributed settings.

    2.

    is_training: Whether or not the layer is in training mode. In training mode
            it would accumulate the statistics of the moments into `moving_mean` and
            `moving_variance` using an exponential moving average with the given
            `decay`. When it is not in training mode then it would use the values of
            the `moving_mean` and the `moving_variance`.
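    The is_training / decay behaviour described above can be traced without TensorFlow. The class below is a hypothetical plain-Python stand-in (its name and attributes are illustrative, not a TensorFlow API). It assumes a 2-D input [batch, features], starts the moving mean at zeros and the moving variance at ones (the tf.contrib.layers.batch_norm defaults), keeps gamma and beta fixed instead of learning them, and updates the moving statistics in place, which is what updates_collections=None requests.

```python
import math


class MovingAverageBatchNorm:
    """Hypothetical plain-Python batch norm for [batch, features] inputs."""

    def __init__(self, num_features, decay=0.9, epsilon=1e-5):
        self.decay = decay
        self.epsilon = epsilon
        self.gamma = [1.0] * num_features            # scale; fixed here, learned in the real layer
        self.beta = [0.0] * num_features             # offset; fixed here, learned in the real layer
        self.moving_mean = [0.0] * num_features      # starts at zeros
        self.moving_variance = [1.0] * num_features  # starts at ones

    def __call__(self, batch, is_training):
        if is_training:
            n = len(batch)
            columns = list(zip(*batch))
            mean = [sum(col) / n for col in columns]
            var = [sum((x - m) ** 2 for x in col) / n for col, m in zip(columns, mean)]
            # Exponential moving average, applied in place on every training call:
            # moving = decay * moving + (1 - decay) * batch_statistic
            d = self.decay
            self.moving_mean = [d * mm + (1 - d) * m for mm, m in zip(self.moving_mean, mean)]
            self.moving_variance = [d * mv + (1 - d) * v for mv, v in zip(self.moving_variance, var)]
        else:
            # Inference: use the accumulated statistics and leave them unchanged.
            mean, var = self.moving_mean, self.moving_variance
        return [
            [(x - m) / math.sqrt(v + self.epsilon) * g + b
             for x, m, v, g, b in zip(row, mean, var, self.gamma, self.beta)]
            for row in batch
        ]


bn = MovingAverageBatchNorm(num_features=1, decay=0.9)
train_batch = [[2.0], [4.0]]  # batch mean 3.0, batch variance 1.0
for _ in range(3):
    bn(train_batch, is_training=True)
print(bn.moving_mean, bn.moving_variance)
print(bn([[3.0]], is_training=False))  # normalized with the moving statistics
```

    After three training calls the moving mean has only reached 3 * (1 - 0.9**3), about 0.813, so at inference the value 3.0 does not map to 0 the way it would with the batch's own statistics.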

    tf.nn.batch_normalization: building it from the low-level interface

    In fact, tf.contrib.layers.batch_norm wraps tf.nn.moments and tf.nn.batch_normalization, and the class above wraps it once more (mainly to fix some default arguments). In practice tf.contrib.layers.batch_norm on its own is already convenient enough.

    Here is a batch_norm function that adds the moving-average handling without any wrapper, built directly on tf.nn.moments and tf.nn.batch_normalization:

    import tensorflow as tf  # TensorFlow 1.x

    def batch_norm(x, beta, gamma, phase_train, scope='bn', decay=0.9, eps=1e-5):
        # x: 4-D NHWC tensor; beta/gamma: per-channel offset and scale; phase_train: scalar bool tensor
        with tf.variable_scope(scope):
            # beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
            # gamma = tf.get_variable(name='gamma', shape=[n_out],
            #                         initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
            batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
            ema = tf.train.ExponentialMovingAverage(decay=decay)
    
            def mean_var_with_update():
                ema_apply_op = ema.apply([batch_mean, batch_var])
                with tf.control_dependencies([ema_apply_op]):
                    # tf.identity creates new ops inside this control_dependencies block,
                    # so they run only after ema_apply_op has updated the moving averages.
                    # Returning batch_mean/batch_var directly would bypass the dependency,
                    # because control_dependencies only affects ops created inside the block.
                    return tf.identity(batch_mean), tf.identity(batch_var)
    
            mean, var = tf.cond(phase_train,
                                mean_var_with_update,
                                lambda: (ema.average(batch_mean), ema.average(batch_var)))
            normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
        return normed
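    `tf.nn.moments(x, [0, 1, 2])` assumes a 4-D NHWC tensor ([batch, height, width, channel]) and returns one mean and one variance per channel, which is why the channel axis is left out of the axes list. A plain-Python sketch of that reduction, assuming nested lists in NHWC order (`channel_moments` is a hypothetical helper, not a TensorFlow function):

```python
def channel_moments(x):
    """Mean and biased variance per channel over batch, height and width (axes 0, 1, 2)."""
    # Flatten every (n, h, w) position while keeping each pixel's channel vector intact.
    pixels = [pixel for image in x for row in image for pixel in row]
    count = len(pixels)
    channels = len(pixels[0])
    mean = [sum(p[c] for p in pixels) / count for c in range(channels)]
    var = [sum((p[c] - mean[c]) ** 2 for p in pixels) / count for c in range(channels)]
    return mean, var


# Batch of 2 images, each 1x2 pixels with 2 channels: shape [2, 1, 2, 2].
x = [
    [[[0.0, 10.0], [2.0, 10.0]]],
    [[[4.0, 30.0], [6.0, 30.0]]],
]
mean, var = channel_moments(x)
print(mean, var)  # [3.0, 20.0] [5.0, 100.0]
```

    The result has one entry per channel, matching the shape of beta and gamma; adding the channel axis to the reduction would instead collapse everything into a single scalar mean and variance.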
    

    Another approach, which writes out the moving-average updates explicitly:

    def batch_norm(x, size, training, decay=0.999):
        # x: 4-D NHWC tensor with `size` channels; training: scalar bool tensor
        beta = tf.Variable(tf.zeros([size]), name='beta')
        scale = tf.Variable(tf.ones([size]), name='scale')
        # population statistics are updated by tf.assign, not by the optimizer
        pop_mean = tf.Variable(tf.zeros([size]), trainable=False)
        pop_var = tf.Variable(tf.ones([size]), trainable=False)
        epsilon = 1e-3
    
        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
    
        def batch_statistics():
            # create the assign ops inside the branch: ops created outside tf.cond run
            # whichever branch is taken, which would also update pop_mean/pop_var at inference
            train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
            train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
            with tf.control_dependencies([train_mean, train_var]):
                return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')
    
        def population_statistics():
            return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm')
    
        return tf.cond(training, batch_statistics, population_statistics)
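    With decay=0.999 and pop_mean starting at zero, each training step moves the population estimate only 0.1% of the way towards the current batch mean. If, purely for illustration, every batch had the same mean m, n applications of the train_mean update would give pop_mean = m * (1 - decay**n). The plain-Python check below iterates that update, compares it with the closed form, and counts the steps needed before the zero-start bias drops under 1%; `ema_after` and `steps_until_within` are hypothetical helpers.

```python
import math


def ema_after(steps, batch_mean, decay, start=0.0):
    """Apply pop = pop * decay + batch_mean * (1 - decay), `steps` times."""
    pop = start
    for _ in range(steps):
        pop = pop * decay + batch_mean * (1 - decay)
    return pop


def steps_until_within(rel_error, decay):
    """Smallest n with decay**n <= rel_error, i.e. the zero-start bias is small enough."""
    return math.ceil(math.log(rel_error) / math.log(decay))


m = 5.0
for decay in (0.9, 0.999):
    iterated = ema_after(100, m, decay)
    closed_form = m * (1 - decay ** 100)
    print(decay, round(iterated, 6), round(closed_form, 6))

print(steps_until_within(0.01, 0.9), steps_until_within(0.01, 0.999))  # 44 4603
```

    In this idealized setting decay=0.9 needs 44 steps but decay=0.999 needs 4603, so a model evaluated after a short run may still be normalizing with population means close to zero. The tf.contrib.layers.batch_norm docstring likewise suggests trying a lower decay such as 0.9 when training performance is good but validation or test performance is poor.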
    

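     Note: tf.cond is a control-flow op. If its first argument (a scalar boolean tensor) is True, it runs the function passed as the second argument; otherwise it runs the function passed as the third. Only ops created inside those two functions execute conditionally; an op created outside tf.cond that a branch depends on runs whichever branch is taken.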

  • Original article: https://www.cnblogs.com/hellcat/p/7380022.html