zoukankan      html  css  js  c++  java
  • tf.identity()函数解析(最清晰的解释)

    欢迎关注WX公众号:【程序员管小亮】

    这两天看batch normalization的代码时,学到滑动平均窗口函数ExponentialMovingAverage时,碰到一个函数tf.identity()函数,特此记录。

    tf.identity()函数用于返回一个和input一样的新的tensor。

    tf.identity(
    	input,
    	name=None
    )
    #Return a tensor with the same shape and contents as input.
    #返回一个tensor,contents和shape都和input的一样
    

    简单来说,就是返回一个和input一样的新的tensor。

    例子1:

    import tensorflow as tf
    w = tf.Variable(1.0)
    ema = tf.train.ExponentialMovingAverage(0.9)
    update = tf.assign_add(w, 1.0)
    
    ema_op = ema.apply([update])
    with tf.control_dependencies([ema_op]):
        ema_val = ema.average(update)
    
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(3):
            print(sess.run([ema_val]))
    
    > [0.0]
    > [0.0]
    > [0.0]
    

    理想的情况下,在我们 sess.run([ema_val]), ema_op 都会被先执行,然后再计算ema_val,实际情况并不是这样,为什么?

    有兴趣的可以看一下源码,就会发现 ema.average(update) 不是一个 op,它只是从ema对象的一个字典中取出键对应的 tensor而已,然后赋值给ema_val。这个 tensor是由一个在 tf.control_dependencies([ema_op]) 外部的一个 op 计算得来的,所以control_dependencies会失效。解决方法也很简单,看代码:

    import tensorflow as tf
    w = tf.Variable(1.0)
    ema = tf.train.ExponentialMovingAverage(0.9)
    update = tf.assign_add(w, 1.0)
    
    ema_op = ema.apply([update])
    with tf.control_dependencies([ema_op]):
        ema_val = tf.identity(ema.average(update)) #一个identity搞定
    
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(3):
            print(sess.run([ema_val]))
    
    > [0.20000005]
    > [0.4800001]
    > [0.8320002]
    

    例子2:

    import tensorflow as tf
    
    x = tf.Variable(0.0)
    x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
    with tf.control_dependencies([x_plus_1]):
        y = x
    init = tf.global_variables_initializer()
    with tf.Session() as session:
        init.run() # 相当于session.run(init)
        for i in range(5):
            print(y.eval()) # y.eval()这个相当于session.run(y)
    
    > 0.0
      0.0
      0.0
      0.0
      0.0
    

    理想的情况下,输出应该是:[1.0, 2.0, 3.0, 4.0, 5.0],实际情况并不是这样,为什么?

    1 tf.control_dependencies()是一个在Graph上的operation,所以要想使得其参数起作用,就需要for循环里面利用sess.run()来执行;

    2 y = x只是一个简单的赋值操作,而with tf.control_dependencies()作用域(也就是冒号下的代码行)只对op起作用,所以需要将tensor利用tf.identity()来转化为op。

    针对以上原因,给出两个相应的解决方法:

    1.
    import tensorflow as tf
    
    x = tf.Variable(0.0)
    x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
    with tf.control_dependencies([x_plus_1]):
        y = x
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        init.run() # 相当于session.run(init)
        for i in range(5):
            sess.run(x_plus_1)
            print(y.eval()) # y.eval()这个相当于session.run(y)
    
    > 1.0
       2.0
       3.0
       4.0
       5.0
    
    2.
    import tensorflow as tf
    
    x = tf.Variable(0.0)
    x_plus_1 = tf.assign_add(x, 1) # 对x进行加1,x_plus_l是个op
    with tf.control_dependencies([x_plus_1]):
        y = tf.identity(x)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        init.run() # 相当于session.run(init)
        for i in range(5):
            print(y.eval()) # y.eval()这个相当于session.run(y)
    
    > 1.0
       2.0
       3.0
       4.0
       5.0
    

    Graph上不论是tensor还是operation的更新都要借助op来进行,而将一个tensor转化为op最简单的方法就是tf.identity()。

    python课程推荐。
    在这里插入图片描述

    参考文章:

    tensorflow学习笔记(四十一):control dependencies
    tf.control_dependencies()和tf.identity()

  • 相关阅读:
    SpringBlade 端口占用 Web server failed to start. Port 80 was already in use.
    SpringBlade 找不到或无法加载主类 springboot.Application
    Java idea 常用快捷键
    Java Velocity
    个人 一些需求
    Java MyBatis-Plus 基本使用
    Java Spring Initializr 创建的项目 包是一层一层的,需要隐藏一下空包
    MapReduce之自定义OutputFormat
    数据链路层之PPP协议
    MapReduce之GroupingComparator分组(辅助排序、二次排序)
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13302847.html
Copyright © 2011-2022 走看看