欢迎关注WX公众号:【程序员管小亮】
tf.Variable()用于生成一个初始值为initial-value的变量;必须指定初始化值。
tf.get_variable()获取已存在的变量(要求不仅名字,而且初始化方法等各个参数都一样),如果不存在,就新建一个;可以用各种初始化方法,不用明确指定值。
一、tf.Variable()
tf.Variable(
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None
)
参数:
-
initial_value:Tensor或可转换为Tensor的Python对象,它是Variable的初始值。除非validate_shape设置为False,否则初始值必须具有指定的形状;也可以是一个可调用的,没有参数,在调用时返回初始值。在这种情况下,必须指定dtype。 (请注意,init_ops.py中的初始化函数必须首先绑定到形状才能在此处使用。) -
trainable:如果为True,则会默认将变量添加到图形集合GraphKeys.TRAINABLE_VARIABLES中。此集合用于Optimizer类优化的的默认变量列表【可为optimizer指定其他的变量集合】,可就是要训练的变量列表。 -
collections:一个图graph集合列表的关键字。新变量将添加到这个集合中。默认为[GraphKeys.GLOBAL_VARIABLES]。也可自己指定其他的集合列表。 -
validate_shape:如果为False,则允许使用未知形状的值初始化变量。如果为True,则默认为initial_value的形状必须已知。 -
caching_device:可选设备字符串,描述应该缓存变量以供读取的位置。默认为Variable的设备。如果不是None,则在另一台设备上缓存。典型用法是在使用变量驻留的Ops的设备上进行缓存,以通过Switch和其他条件语句进行重复数据删除。 -
name:变量的可选名称。默认为“Variable”并自动获取。 -
variable_def:VariableDef协议缓冲区。如果不是None,则使用其内容重新创建Variable对象,引用图中必须已存在的变量节点。图表未更改。variable_def和其他参数是互斥的。 -
dtype:如果设置,则initial_value将转换为给定类型。如果为None,则保留数据类型(如果initial_value是Tensor),或者convert_to_tensor将决定。 -
expected_shape:TensorShape。如果设置,则initial_value应具有此形状。 -
import_scope:可选字符串。要添加到变量的名称范围。仅在从协议缓冲区初始化时使用。
一般常用的参数包括初始化值和名称name(是该变量的唯一索引),在使用变量之前必须要进行初始化,初始化的方式有三种:
- 在会话中运行
initializer操作。 - 从文件中恢复,如
restore from checkpoint。 - 自己通过
tf.assign()给变量附初值。
二、tf.get_variable()
get_variable(
name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
constraint=None
)
参数:
-
name:新变量或现有变量的名称。 -
shape:新变量或现有变量的形状。 -
dtype:新变量或现有变量的类型(默认为DT_FLOAT)。 -
ininializer:如果创建了,则用它来初始化变量。 -
regularizer:A(Tensor - > Tensor或None)函数;将它应用于新创建的变量的结果将添加到集合tf.GraphKeys.REGULARIZATION_LOSSES中,并可用于正则化。 -
trainable:如果为True,还将变量添加到图形集合GraphKeys.TRAINABLE_VARIABLES(参见tf.Variable)。 -
collections:要将变量添加到的图表集合列表。默认为[GraphKeys.GLOBAL_VARIABLES](参见tf.Variable)。 -
caching_device:可选的设备字符串或函数,描述变量应被缓存以供读取的位置。默认为Variable的设备。如果不是None,则在另一台设备上缓存。典型用法是在使用变量驻留的Ops的设备上进行缓存,以通过Switch和其他条件语句进行重复数据删除。 -
partitioner:可选callable,接受完全定义的TensorShape和要创建的Variable的dtype,并返回每个轴的分区列表(当前只能对一个轴进行分区)。 -
validate_shape:如果为False,则允许使用未知形状的值初始化变量。如果为True,则默认为initial_value的形状必须已知。 -
use_resource:如果为False,则创建常规变量。如果为true,则使用定义良好的语义创建实验性ResourceVariable。默认为False(稍后将更改为True)。在Eager模式下,此参数始终强制为True。 -
custom_getter:Callable,它将第一个参数作为true getter,并允许覆盖内部get_variable方法。custom_getter的签名应与此方法的签名相匹配,但最适合未来的版本将允许更改:def custom_getter(getter,* args,** kwargs)。也允许直接访问所有get_variable参数:def custom_getter(getter,name,* args,** kwargs)。一个简单的身份自定义getter只需创建具有修改名称的变量是:python def custom_getter(getter,name,* args,** kwargs):return getter(name +'_suffix',* args,** kwargs)。
如果initializer初始化方法是None(默认值),则会使用variable_scope()中定义的initializer,如果也为None,则默认使用glorot_uniform_initializer,也可以使用其他的tensor来初始化,value、和shape与此tensor相同。
正则化方法默认是None,如果不指定,只会使用variable_scope()中的正则化方式,如果也为None,则不使用正则化;
三、区别
推荐使用tf.get_variable(), 因为:
- 初始化更方便
比如用xavier_initializer:
W = tf.get_variable("W", shape=[784, 256], initializer=tf.contrib.layers.xavier_initializer())
- 方便共享变量
因为tf.get_variable()会检查当前命名空间下是否存在同样name的变量,可以方便共享变量。而tf.Variable每次都会新建一个变量。
需要注意的是tf.get_variable(),要配合reuse和tf.variable_scope()使用,对于get_variable()来说,如果已经创建的变量对象,就把那个对象返回,如果没有创建变量对象的话,就创建一个新的。
例子1:
import tensorflow as tf
w_1 = tf.Variable(3,name="w_1")
w_2 = tf.Variable(1,name="w_1")
print(w_1.name)
print(w_2.name)
# 输出
# w_1:0
# w_1_1:0
import tensorflow as tf
w_1 = tf.get_variable(name="w_1",initializer=1)
w_2 = tf.get_variable(name="w_1",initializer=2)
# 错误信息
# ValueError: Variable w_1 already exists, disallowed.
# Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
例子2:
import tensorflow as tf
with tf.variable_scope("scope1"):
w1 = tf.get_variable("w1", shape=[])
w2 = tf.Variable(0.0, name="w2")
with tf.variable_scope("scope1", reuse=True):
w1_p = tf.get_variable("w1", shape=[])
w2_p = tf.Variable(1.0, name="w2")
print(w1 is w1_p, w2 is w2_p)
#输出
#True False
四、实例
import tensorflow as tf
with tf.variable_scope("one"):
a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
b = tf.get_variable("v", [1]) #创建两个名字一样的变量会报错 ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True): #注意reuse的作用。
c = tf.get_variable("v", [1]) #c.name == "one/v:0" 成功共享,因为设置了reuse
assert a==c #Assertion is true, they refer to the same object.
with tf.variable_scope("two"):
d = tf.get_variable("v", [1]) #d.name == "two/v:0"
e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"
assert d==e #AssertionError: they are different objects
python课程推荐。
