zoukankan      html  css  js  c++  java
  • tf.add_to_collection 和 tf.get_collection 和 tf.add_n

    简要介绍

    tf.add_to_collection:把多个变量放在一个 自己命名 的集合里,包括不同域内的变量

    tf.get_collection:读取一个列表,生成一个新列表

    tf.add_n:把一个列表里的元素求和

    add_to_collection(name, value)

    name 为集合名,value 为 变量;

    通常 tensorflow 会把变量和可训练的变量自动收集起来,包括不同域的变量;

    变量对应的集合名字叫 variables,或者叫 tf.GraphKeys.VARIABLES;

    可训练的变量对应的集合名字为 trainable_variables,或者叫 tf.GraphKeys.TRAINABLE_VARIABLES;

    print(tf.GraphKeys.VARIABLES)           # variables
    print(tf.GraphKeys.TRAINABLE_VARIABLES) # trainable_variables

    示例

    with tf.name_scope('test1') as test1:
        v1 = tf.Variable(1)
        tf.add_to_collection('all', v1)     ### 显式加入集合
    
    with tf.name_scope('test2') as test2:
        v2 = tf.Variable(2)
        tf.add_to_collection('all', v2)     ### 显式加入集合
    
    for i in tf.get_collection(tf.GraphKeys.VARIABLES):     ### tf 自动收集
        print(i)
    # <tf.Variable 'test1/Variable:0' shape=() dtype=int32_ref>
    # <tf.Variable 'test2/Variable:0' shape=() dtype=int32_ref>
    
    for j in tf.get_collection('all'):
        print(j)
    # <tf.Variable 'test1/Variable:0' shape=() dtype=int32_ref>
    # <tf.Variable 'test2/Variable:0' shape=() dtype=int32_ref>
    
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        print(sess.run(tf.add_n(tf.get_collection('all')))) # 3

    它的作用是不停地记录关注变量,然后求和

    d1 = tf.Variable(1)
    d2 = tf.Variable(2)
    d3 = tf.Variable(3)
    tf.add_to_collection('sum', d1)
    tf.add_to_collection('sum', d2)
    tf.add_to_collection('sum', d3)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        print(sess.run(tf.add_n(tf.get_collection('sum'))))     # 6

    get_collection(key, scope=None)

    key 集合名,scope 作用域

    示例

    v1 = tf.Variable(1, name='v1')
    v2 = tf.get_variable(name='v2', initializer=2)
    v3 = tf.Variable(3, name='v3', trainable=False)
    
    print(tf.get_variable_scope().name)       #
    print(tf.GraphKeys.VARIABLES)           # variables
    print(tf.GraphKeys.TRAINABLE_VARIABLES) # trainable_variables
    
    ### 获取全部变量 key=variables,scope=None
    for j in tf.get_collection(tf.GraphKeys.VARIABLES):
        print(j)
    # <tf.Variable 'v1:0' shape=() dtype=int32_ref>
    # <tf.Variable 'v2:0' shape=() dtype=int32_ref>
    # <tf.Variable 'v3:0' shape=() dtype=int32_ref>
    
    ### 获取全部可训练变量 key=trainable_variables,scope=None
    for k in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        print(k)
    # <tf.Variable 'v1:0' shape=() dtype=int32_ref>
    # <tf.Variable 'v2:0' shape=() dtype=int32_ref>
    
    ### 获取全部可训练变量 key=trainable_variables,scope=None,等价于上个操作
    for m in tf.get_collection('trainable_variables'):
        print(m)
    # <tf.Variable 'v1:0' shape=() dtype=int32_ref>
    # <tf.Variable 'v2:0' shape=() dtype=int32_ref>

    示例2:指定作用域,接上例

    ### 增加一个作用域
    with tf.name_scope('test') as test:
        v4 = tf.Variable(4, name='v4')
        v5 = tf.Variable(5, name='v5', trainable=False)
    
    ### 获取全部可训练变量 key=trainable_variables,scope=None,包括新的作用域
    for s in tf.get_collection('trainable_variables'):
        print(s)
    # <tf.Variable 'v1:0' shape=() dtype=int32_ref>
    # <tf.Variable 'v2:0' shape=() dtype=int32_ref>
    # <tf.Variable 'test/v4:0' shape=() dtype=int32_ref>
    
    ### 获取指定作用域下的可训练变量 key=trainable_variables,scope=test
    for t in tf.get_collection('trainable_variables', test):
        print(t)
    # <tf.Variable 'test/v4:0' shape=() dtype=int32_ref>

    add_n(inputs, name=None)

    很简单了,上面的例子中有用到

    参考资料:

    https://blog.csdn.net/uestc_c2_403/article/details/72415791

    https://blog.csdn.net/nini_coded/article/details/80528466

  • 相关阅读:
    on duplicate key update之多列唯一索引
    js 判断 微信浏览器 安卓/苹果 pc/移动
    history 和 hash (转)
    路由vue-router
    添加图标ico
    vue项目结构
    vue2.0项目的构建
    echarts使用 图例改变和默认不选中
    微信自定义菜单设置 及 emoji表情更换
    复制/设置剪切板内容 (浏览器/nativejs)
  • 原文地址:https://www.cnblogs.com/yanshw/p/12435071.html
Copyright © 2011-2022 走看看