zoukankan      html  css  js  c++  java
  • tf.get_variable()和tf.Variable()的区别(最清晰的解释)

    欢迎关注WX公众号:【程序员管小亮】

    tf.Variable()用于生成一个初始值为initial-value的变量;必须指定初始化值。

    tf.get_variable()获取已存在的变量(要求不仅名字,而且初始化方法等各个参数都一样),如果不存在,就新建一个;可以用各种初始化方法,不用明确指定值。

    一、tf.Variable()

    tf.Variable(
    	initial_value=None, 
    	trainable=True, 
    	collections=None, 
    	validate_shape=True, 
    	caching_device=None, 
    	name=None, 
    	variable_def=None, 
    	dtype=None, 
    	expected_shape=None, 
    	import_scope=None
    )
    

    参数:

    • initial_valueTensor或可转换为Tensor的Python对象,它是Variable的初始值。除非validate_shape设置为False,否则初始值必须具有指定的形状;也可以是一个可调用的,没有参数,在调用时返回初始值。在这种情况下,必须指定dtype。 (请注意,init_ops.py中的初始化函数必须首先绑定到形状才能在此处使用。)

    • trainable如果为True,则会默认将变量添加到图形集合GraphKeys.TRAINABLE_VARIABLES中。此集合用于Optimizer类优化的的默认变量列表【可为optimizer指定其他的变量集合】,可就是要训练的变量列表。

    • collections:一个图graph集合列表的关键字。新变量将添加到这个集合中。默认为[GraphKeys.GLOBAL_VARIABLES]。也可自己指定其他的集合列表。

    • validate_shape:如果为False,则允许使用未知形状的值初始化变量。如果为True,则默认为initial_value的形状必须已知。

    • caching_device:可选设备字符串,描述应该缓存变量以供读取的位置。默认为Variable的设备。如果不是None,则在另一台设备上缓存。典型用法是在使用变量驻留的Ops的设备上进行缓存,以通过Switch和其他条件语句进行重复数据删除。

    • name变量的可选名称。默认为“Variable”并自动获取。

    • variable_def:VariableDef协议缓冲区。如果不是None,则使用其内容重新创建Variable对象,引用图中必须已存在的变量节点。图表未更改。variable_def和其他参数是互斥的。

    • dtype:如果设置,则initial_value将转换为给定类型。如果为None,则保留数据类型(如果initial_value是Tensor),或者convert_to_tensor将决定。

    • expected_shapeTensorShape。如果设置,则initial_value应具有此形状。

    • import_scope:可选字符串。要添加到变量的名称范围。仅在从协议缓冲区初始化时使用。

    一般常用的参数包括初始化值和名称name(是该变量的唯一索引),在使用变量之前必须要进行初始化,初始化的方式有三种:

    1. 在会话中运行initializer操作。
    2. 从文件中恢复,如restore from checkpoint
    3. 自己通过tf.assign()给变量附初值。

    二、tf.get_variable()

    get_variable(
        name,
        shape=None,
        dtype=None,
        initializer=None,
        regularizer=None,
        trainable=True,
        collections=None,
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,
        custom_getter=None,
        constraint=None
    )
    

    参数:

    • name新变量或现有变量的名称。

    • shape:新变量或现有变量的形状。

    • dtype:新变量或现有变量的类型(默认为DT_FLOAT)。

    • ininializer如果创建了,则用它来初始化变量。

    • regularizerA(Tensor - > Tensor或None)函数;将它应用于新创建的变量的结果将添加到集合tf.GraphKeys.REGULARIZATION_LOSSES中,并可用于正则化。

    • trainable如果为True,还将变量添加到图形集合GraphKeys.TRAINABLE_VARIABLES(参见tf.Variable)。

    • collections:要将变量添加到的图表集合列表。默认为[GraphKeys.GLOBAL_VARIABLES](参见tf.Variable)。

    • caching_device:可选的设备字符串或函数,描述变量应被缓存以供读取的位置。默认为Variable的设备。如果不是None,则在另一台设备上缓存。典型用法是在使用变量驻留的Ops的设备上进行缓存,以通过Switch和其他条件语句进行重复数据删除。

    • partitioner:可选callable,接受完全定义的TensorShape和要创建的Variable的dtype,并返回每个轴的分区列表(当前只能对一个轴进行分区)。

    • validate_shape:如果为False,则允许使用未知形状的值初始化变量。如果为True,则默认为initial_value的形状必须已知。

    • use_resource:如果为False,则创建常规变量。如果为true,则使用定义良好的语义创建实验性ResourceVariable。默认为False(稍后将更改为True)。在Eager模式下,此参数始终强制为True。

    • custom_getter:Callable,它将第一个参数作为true getter,并允许覆盖内部get_variable方法。 custom_getter的签名应与此方法的签名相匹配,但最适合未来的版本将允许更改:def custom_getter(getter,* args,** kwargs)。也允许直接访问所有get_variable参数:def custom_getter(getter,name,* args,** kwargs)。一个简单的身份自定义getter只需创建具有修改名称的变量是:python def custom_getter(getter,name,* args,** kwargs):return getter(name +'_suffix',* args,** kwargs)

    如果initializer初始化方法是None(默认值),则会使用variable_scope()中定义的initializer,如果也为None,则默认使用glorot_uniform_initializer,也可以使用其他的tensor来初始化,value、和shape与此tensor相同。

    正则化方法默认是None,如果不指定,只会使用variable_scope()中的正则化方式,如果也为None,则不使用正则化;

    三、区别

    推荐使用tf.get_variable(), 因为:

    1. 初始化更方便

    比如用xavier_initializer:

    W = tf.get_variable("W", shape=[784, 256], initializer=tf.contrib.layers.xavier_initializer())
    
    1. 方便共享变量

    因为tf.get_variable()会检查当前命名空间下是否存在同样name的变量,可以方便共享变量。而tf.Variable每次都会新建一个变量。

    需要注意的是tf.get_variable(),要配合reusetf.variable_scope()使用,对于get_variable()来说,如果已经创建的变量对象,就把那个对象返回,如果没有创建变量对象的话,就创建一个新的。

    例子1:

    import tensorflow as tf
    w_1 = tf.Variable(3,name="w_1")
    w_2 = tf.Variable(1,name="w_1")
    print(w_1.name)
    print(w_2.name)
    # 输出
    # w_1:0
    # w_1_1:0
    
    import tensorflow as tf
    w_1 = tf.get_variable(name="w_1",initializer=1)
    w_2 = tf.get_variable(name="w_1",initializer=2)
    # 错误信息
    # ValueError: Variable w_1 already exists, disallowed. 
    # Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? 
    

    例子2:

    import tensorflow as tf
    
    with tf.variable_scope("scope1"):
        w1 = tf.get_variable("w1", shape=[])
        w2 = tf.Variable(0.0, name="w2")
    with tf.variable_scope("scope1", reuse=True):
        w1_p = tf.get_variable("w1", shape=[])
        w2_p = tf.Variable(1.0, name="w2")
    
    print(w1 is w1_p, w2 is w2_p)
    #输出
    #True  False
    

    四、实例

    import tensorflow as tf
    
    with tf.variable_scope("one"):
        a = tf.get_variable("v", [1]) #a.name == "one/v:0"
    with tf.variable_scope("one"):
        b = tf.get_variable("v", [1]) #创建两个名字一样的变量会报错 ValueError: Variable one/v already exists 
    with tf.variable_scope("one", reuse = True): #注意reuse的作用。
        c = tf.get_variable("v", [1]) #c.name == "one/v:0" 成功共享,因为设置了reuse
    
    assert a==c #Assertion is true, they refer to the same object.
    
    with tf.variable_scope("two"):
        d = tf.get_variable("v", [1]) #d.name == "two/v:0"
        e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"  
    
    assert d==e #AssertionError: they are different objects
    

    python课程推荐。
    在这里插入图片描述

  • 相关阅读:
    CAS 认证
    最近邻规则分类(k-Nearest Neighbor )机器学习算法python实现
    scikit-learn决策树的python实现以及作图
    module object has no attribute dumps的解决方法
    最新Flume1.7 自定义 MongodbSink 结合TAILDIR Sources的使用
    数据探索中的贡献度分析
    python logging模块按天滚动简单程序
    Flume性能测试报告(翻译Flume官方wiki报告)
    python apsheduler cron 参数解析
    python pyspark入门篇
  • 原文地址:https://www.cnblogs.com/hzcya1995/p/13302805.html
Copyright © 2011-2022 走看看