tf.Variable
__init__( initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None, dtype=None, expected_shape=None, import_scope=None )
功能说明:
维护图在执行过程中的状态信息,例如神经网络权重值的变化。
参数列表:
参数名 | 类型 | 说明 |
---|---|---|
initial_value | 张量 | Variable 类的初始值,这个变量必须指定 shape 信息,否则后面 validate_shape 需设为 False |
trainable | Boolean | 是否把变量添加到 collection GraphKeys.TRAINABLE_VARIABLES 中(collection 是一种全局存储,不受变量名生存空间影响,一处保存,到处可取) |
collections | Graph collections | 全局存储,默认是 GraphKeys.GLOBAL_VARIABLES |
validate_shape | Boolean | 是否允许被未知维度的 initial_value 初始化 |
caching_device | string | 指明哪个 device 用来缓存变量 |
name | string | 变量名 |
dtype | dtype | 如果被设置,初始化的值就会按照这个类型初始化 |
expected_shape | TensorShape | 要是设置了,那么初始的值会是这种维度 |
示例代码:
import tensorflow as tf initial= tf.truncated_normal(shape=[10,10],mean=0,stddev=1) W=tf.Variable(initial) list=[[1.,1.],[2.,2.]] X=tf.Variable(list,dtype=tf.float32) ini_op=tf.global_variables_initializer() with tf.Session() as sess: sess.run(ini_op) print(sess.run(W[:2,:2])) op=W[:2,:2].assign(22.*tf.ones((2,2))) print(sess.run(op)) print (W.eval()) #Usage with the default session print ("#####################(6)#############") print (W.dtype) print (sess.run(W.initial_value)) print (sess.run(W.op)) print (W.shape) print ("###################(7)###############") print (sess.run(X))