zoukankan      html  css  js  c++  java
  • tensorflow 下的滑动平均模型 —— tf.train.ExponentialMovingAverage

    在采用随机梯度下降算法训练神经网络时,使用 tf.train.ExponentialMovingAverage 滑动平均操作的意义在于提高模型在测试数据上的健壮性(robustness)

    tensorflow 下的 tf.train.ExponentialMovingAverage 需要提供一个衰减率(decay)。该衰减率用于控制模型更新的速度。该衰减率用于控制模型更新的速度,ExponentialMovingAverage 对每一个(待更新训练学习的)变量(variable)都会维护一个影子变量(shadow variable)。影子变量的初始值就是这个变量的初始值,

    shadow_variable=decay×shadow_variable+(1decay)×variable

    由上述公式可知, decay 控制着模型更新的速度,越大越趋于稳定。实际运用中,decay 一般会设置为十分接近 1 的常数(0.99或0.999)。为了使得模型在训练的初始阶段更新得更快,ExponentialMovingAverage 还提供了 num_updates 参数来动态设置 decay 的大小:

    decay=min{decay,1+num_updates10+num_updates}

    import tensorflow as tf
    
    
    v1 =tf.Variable(dtype=tf.float32, initial_value=0.)
    decay = .99
    num_updates = tf.Variable(0, trainable=False)
    ema = tf.train.ExponentialMovingAverage(decay=decay, num_updates=num_updates)
    
    update_var_list = [v1]      # 定义更新变量列表
    ema_apply = ema.apply(update_var_list)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run([v1, ema.average(v1)]))             
                                                    # [0.0, 0.0](此时 num_updates = 0 ⇒ decay = .1, ),shadow_variable = variable = 0.
    
        sess.run(tf.assign(v1, 5))
        sess.run(ema_apply)
        print(sess.run([v1, ema.average(v1)]))
                                                    # 此时,num_updates = 0 ⇒ decay =.1,  v1 = 5;
                                                    # shadow_variable = 0.1 * 0 + 0.9 * 5 = 4.5 ⇒ variable
        sess.run(tf.assign(num_updates, 10000))
        sess.run(tf.assign(v1, 10))
        sess.run(ema_apply)
        print(sess.run([v1, ema.average(v1)]))
                                                    # decay = .99,
                                                    # shadow_variable = 0.99 * 4.5 + .01*10 ⇒ 4.555
        sess.run(ema_apply)
        print(sess.run([v1, ema.average(v1)]))
                                                    # decay = .99
                                                    # shadow_variable = .99*4.555 + .01*10 = 4.609  
  • 相关阅读:
    JSON基础(Java)
    美式英语音标词对照表
    network adapter
    debian网络静态ip配置
    apt --fix-broken install
    CA certificate
    debian使用过程中常见的问题
    将普通用户添加到sudo
    nano
    jenkins安装和使用
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9421617.html
Copyright © 2011-2022 走看看