欢迎关注WX公众号:【程序员管小亮】
这两天看batch normalization的代码时,学到滑动平均窗口函数ExponentialMovingAverage时,碰到一个函数tf.identity()函数,特此记录。
tf.identity()函数用于返回一个和input一样的新的tensor。
tf.identity(
input,
name=None
)
#Return a tensor with the same shape and contents as input.
#返回一个tensor,contents和shape都和input的一样
简单来说,就是返回一个和input一样的新的tensor。
例子1:
import tensorflow as tf
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)
ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
ema_val = ema.average(update)
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(3):
print(sess.run([ema_val]))
> [0.0]
> [0.0]
> [0.0]
理想的情况下,在我们 sess.run([ema_val]), ema_op 都会被先执行,然后再计算ema_val,实际情况并不是这样,为什么?
有兴趣的可以看一下源码,就会发现 ema.average(update) 不是一个 op,它只是从ema对象的一个字典中取出键对应的 tensor而已,然后赋值给ema_val。这个 tensor是由一个在 tf.control_dependencies([ema_op]) 外部的一个 op 计算得来的,所以control_dependencies会失效。解决方法也很简单,看代码:
import tensorflow as tf
w = tf.Variable(1.0)
ema = tf.train.ExponentialMovingAverage(0.9)
update = tf.assign_add(w, 1.0)
ema_op = ema.apply([update])
with tf.control_dependencies([ema_op]):
ema_val = tf.identity(ema.average(update)) #一个identity搞定
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(3):
print(sess.run([ema_val]))
> [0.20000005]
> [0.4800001]
> [0.8320002]
例子2:
import tensorflow as tf
x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
y = x
init = tf.global_variables_initializer()
with tf.Session() as session:
init.run() # 相当于session.run(init)
for i in range(5):
print(y.eval()) # y.eval()这个相当于session.run(y)
> 0.0
0.0
0.0
0.0
0.0
理想的情况下,输出应该是:[1.0, 2.0, 3.0, 4.0, 5.0],实际情况并不是这样,为什么?
1 tf.control_dependencies()是一个在Graph上的operation,所以要想使得其参数起作用,就需要for循环里面利用sess.run()来执行;
2 y = x只是一个简单的赋值操作,而with tf.control_dependencies()作用域(也就是冒号下的代码行)只对op起作用,所以需要将tensor利用tf.identity()来转化为op。
针对以上原因,给出两个相应的解决方法:
1.
import tensorflow as tf
x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
y = x
init = tf.global_variables_initializer()
with tf.Session() as sess:
init.run() # 相当于session.run(init)
for i in range(5):
sess.run(x_plus_1)
print(y.eval()) # y.eval()这个相当于session.run(y)
> 1.0
2.0
3.0
4.0
5.0
2.
import tensorflow as tf
x = tf.Variable(0.0)
x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
with tf.control_dependencies([x_plus_1]):
y = tf.identity(x)
init = tf.global_variables_initializer()
with tf.Session() as sess:
init.run() # 相当于session.run(init)
for i in range(5):
print(y.eval()) # y.eval()这个相当于session.run(y)
> 1.0
2.0
3.0
4.0
5.0
Graph上不论是tensor还是operation的更新都要借助op来进行,而将一个tensor转化为op最简单的方法就是tf.identity()。
python课程推荐。
参考文章:
tensorflow学习笔记(四十一):control dependencies
tf.control_dependencies()和tf.identity()