简要介绍
tf.add_to_collection:把多个变量放在一个 自己命名 的集合里,包括不同域内的变量
tf.get_collection:读取一个列表,生成一个新列表
tf.add_n:把一个列表里的元素求和
add_to_collection(name, value)
name 为集合名,value 为 变量;
通常 tensorflow 会把变量和可训练的变量自动收集起来,包括不同域的变量;
变量对应的集合名字叫 variables,或者叫 tf.GraphKeys.VARIABLES;
可训练的变量对应的集合名字为 trainable_variables,或者叫 tf.GraphKeys.TRAINABLE_VARIABLES;
print(tf.GraphKeys.VARIABLES) # variables print(tf.GraphKeys.TRAINABLE_VARIABLES) # trainable_variables
示例
with tf.name_scope('test1') as test1: v1 = tf.Variable(1) tf.add_to_collection('all', v1) ### 显式加入集合 with tf.name_scope('test2') as test2: v2 = tf.Variable(2) tf.add_to_collection('all', v2) ### 显式加入集合 for i in tf.get_collection(tf.GraphKeys.VARIABLES): ### tf 自动收集 print(i) # <tf.Variable 'test1/Variable:0' shape=() dtype=int32_ref> # <tf.Variable 'test2/Variable:0' shape=() dtype=int32_ref> for j in tf.get_collection('all'): print(j) # <tf.Variable 'test1/Variable:0' shape=() dtype=int32_ref> # <tf.Variable 'test2/Variable:0' shape=() dtype=int32_ref> init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) print(sess.run(tf.add_n(tf.get_collection('all')))) # 3
它的作用是不停地记录关注变量,然后求和
d1 = tf.Variable(1) d2 = tf.Variable(2) d3 = tf.Variable(3) tf.add_to_collection('sum', d1) tf.add_to_collection('sum', d2) tf.add_to_collection('sum', d3) init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) print(sess.run(tf.add_n(tf.get_collection('sum')))) # 6
get_collection(key, scope=None)
key 集合名,scope 作用域
示例
v1 = tf.Variable(1, name='v1') v2 = tf.get_variable(name='v2', initializer=2) v3 = tf.Variable(3, name='v3', trainable=False) print(tf.get_variable_scope().name) # 空 print(tf.GraphKeys.VARIABLES) # variables print(tf.GraphKeys.TRAINABLE_VARIABLES) # trainable_variables ### 获取全部变量 key=variables,scope=None for j in tf.get_collection(tf.GraphKeys.VARIABLES): print(j) # <tf.Variable 'v1:0' shape=() dtype=int32_ref> # <tf.Variable 'v2:0' shape=() dtype=int32_ref> # <tf.Variable 'v3:0' shape=() dtype=int32_ref> ### 获取全部可训练变量 key=trainable_variables,scope=None for k in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): print(k) # <tf.Variable 'v1:0' shape=() dtype=int32_ref> # <tf.Variable 'v2:0' shape=() dtype=int32_ref> ### 获取全部可训练变量 key=trainable_variables,scope=None,等价于上个操作 for m in tf.get_collection('trainable_variables'): print(m) # <tf.Variable 'v1:0' shape=() dtype=int32_ref> # <tf.Variable 'v2:0' shape=() dtype=int32_ref>
示例2:指定作用域,接上例
### 增加一个作用域 with tf.name_scope('test') as test: v4 = tf.Variable(4, name='v4') v5 = tf.Variable(5, name='v5', trainable=False) ### 获取全部可训练变量 key=trainable_variables,scope=None,包括新的作用域 for s in tf.get_collection('trainable_variables'): print(s) # <tf.Variable 'v1:0' shape=() dtype=int32_ref> # <tf.Variable 'v2:0' shape=() dtype=int32_ref> # <tf.Variable 'test/v4:0' shape=() dtype=int32_ref> ### 获取指定作用域下的可训练变量 key=trainable_variables,scope=test for t in tf.get_collection('trainable_variables', test): print(t) # <tf.Variable 'test/v4:0' shape=() dtype=int32_ref>
add_n(inputs, name=None)
很简单了,上面的例子中有用到
参考资料:
https://blog.csdn.net/uestc_c2_403/article/details/72415791
https://blog.csdn.net/nini_coded/article/details/80528466