转自 某大佬的公众号
为什么要使用滑动平均模型?
通过使用滑动平均我们可以使神经网络模型在测试数据上更健壮,在使用随机梯度下降算法训练神经网络时,通过滑动平均模型可以在一定程度上提高最终模型在测试数据上的表现:
它通过控制衰减率(decay)来控制参数更新前后之间的差距,从而达到减缓参数的变化幅度的目的(如,参数更新前是5,更新后的值是4,通过滑动平均模型之后,参数的值会在4到5之间),如果参数更新前后的值保持不变,通过滑动平均模型之后,参数的值仍然保持不变。
TensorFlow中的ExponentialMovingAverage函数实现了滑动平均模型:
tf.train.ExponentialMovingAverage(decay,num_updates=None,zero_debias=False,name="ExponentialMovingAverage")
其中decay为衰减率,num_updates为可选参数,你可以默认为None,如果设置了num_updates,那么这个参数就代表了模型更新的次数。如果在ExponentialMovingAverage函数中初始化了num_updates参数,那么每次使用的衰减率将会按照如下公式更新:
衰减率更新公式:
decay = min{init_decay , (1 + num_update) / (10 + num_update)}
可见随着 num_update 更新次数的增加,(1 + num_update) / (10 + num_update 这一项的计算结果越接近1
此时原模型参数按照以下公式更新:
shadow_variable = decay * shadow_variable + (1 - decay) * variable
其中 shadow_variable 为变量更新前的数值,我们叫它影子变量,variable为变量更新后的数值,上面我们说到,随着num_updates的增大,decay的计算结果越来越接近1,那么1-decay就趋近于0,即,模型参数更新变得越来越慢。
这样,我们就可以通过对num_updates的控制来使得模型在训练初期参数更新幅度加大,在接近最优值处参数更新幅度减小,即在减少训练时间的基础上保证模型训练的精度。
如何使用滑动平均模型?
这里301给出在Tensorflow中一个简单的滑动模型的应用,并会给出详尽的注释!
import tensorflow as tf #定义一个变量用于计算滑动平均,初始值为 0 ,类型为实数 v1 = tf.Variable(0, dtype = tf.float32) #step变量用来模拟神经网络中迭代的轮数,即我们上面说的num_updates参数,用来动态控制衰减率 step = tf.Variable(0, trainable = False) #定义一个滑动平均类,初始化衰减率(0.99)和衰减率控制变量step #该函数返回一个ExponentialMovingAverage对象,该对象调用apply方法可以通过滑动平均模型来更新参数 ema = tf.train.ExponentialMovingAverage(0.99, step) #定义一个更新变量滑动平均的操作。 #这里的给定数据需要是列表的形式,每次执行这个操作时列表中的变量都会被更新 maintain_averages_op = ema.apply([v1]) with tf.Session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) #通过ema.average(v1)获取滑动平均之后变量的取值,此处输出为[0.0 , 0.0] #初始化之后变量v1和v1的滑动平均都为0 print sess.run([v1, ema.average(v1)]) #更新变量v1的值为5 sess.run(tf.assign(v1, 5)) #更新v1的滑动平均值 #此时衰减率为min(0.99,(1+step)/(10+step)=0.1) = 0.1 #所以v1的滑动平均会被更新为0.1*0 + 0.9*5 = 4.5 sess.run(maintain_averages_op) print sess.run([v1, ema.average(v1)]) #输出[5.0 , 4.5] #更新step的值为10000 sess.run(tf.assign(step, 10000)) #更新v1的值为10 sess.run(tf.assign(v1, 10)) #计算v1的滑动平均值 #此时衰减率为min(0.99,(1+step)/(10+step)=0.999999) = 0.99 #所以v1的滑动平均会被更新为0.99*4.5 + 0.01*10 = 4.555 sess.run(maintain_averages_op) print sess.run([v1, ema.average(v1)]) #输出[10.0, 4.5549998] #再次更新滑动平均值,得到的新的滑动平均值为0.99*4.555 + 0.01*10 = 4.60945 sess.run(maintain_averages_op) print sess.run([v1, ema.average(v1)]) #输出[10.0, 4.6094499]