1、tf.ConfigProto
tf.ConfigProto一般用在创建session的时候,用来对session进行参数配置:
with tf.Session(config=tf.ConfigProto(),...): # tf.ConfigProto()的参数 log_device_placement=True #是否打印设备分配日志 allow_soft_placement=True #如果你指定的设备不存在,允许TF自动分配设备 tf.Session(config=tf.ConfigProto(allow_soft_placement=True,log_device_placement=True))
2、tf.add_to_collection ,tf.add_n 的用法:
前者用于收集collection,后者将收集好的collection进行汇总
import tensorflow as tf a=tf.get_variable('a',shape=[1],initializer=tf.constant_initializer([1])) tf.add_to_collection('loss',a) b=tf.get_variable('b',shape=[1],initializer=tf.constant_initializer([1])) tf.add_to_collection('loss',b) #通过tf.add_to_collection进行收集 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.get_collection('loss')) #从一个集合中取出全部变量,结果是一个列表 print(sess.run(tf.add_n(tf.get_collection('loss')))) #将列表取出的内容进行累加