zoukankan      html  css  js  c++  java
  • TensorFlow的变量管理:变量作用域机制

    在深度学习中,你可能需要用到大量的变量集,而且这些变量集可能在多处都要用到。例如,训练模型时,训练参数如权重(weights)、偏置(biases)等已经定下来,要拿到验证集去验证,我们自然希望这些参数是同一组。以往写简单的程序,可能使用全局限量就可以了,但在深度学习中,这显然是不行的,一方面不便管理,另外这样一来代码的封装性受到极大影响。因此,TensorFlow提供了一种变量管理方法:变量作用域机制,以此解决上面出现的问题。

    TensorFlow的变量作用域机制依赖于以下两个方法,官方文档中定义如下:

    [plain] view plain copy
     
    1. tf.get_variable(name, shape, initializer): Creates or returns a variable with a given name.建立或返回一个给定名称的变量  
    2. tf.variable_scope( scope_name): Manages namespaces for names passed to tf.get_variable(). 管理传递给tf.get_variable()的变量名组成的命名空间  

    先说说tf.get_variable(),这个方法在建立新的变量时与tf.Variable()完全相同。它的特殊之处在于,他还会搜索是否有同名的变量。创建变量用法如下:

    [plain] view plain copy
     
    1. with tf.variable_scope("foo"):  
    2.     with tf.variable_scope("bar"):  
    3.         v = tf.get_variable("v", [1])  
    4.         assert v.name == "foo/bar/v:0"  


    而tf.variable_scope(scope_name),它会管理在名为scope_name的域(scope)下传递给tf.get_variable的所有变量名(组成了一个变量空间),根据规则确定这些变量是否进行复用。这个方法最重要的参数是reuse,有None,tf.AUTO_REUSE与True三个选项。具体用法如下:

    1. reuse的默认选项是None,此时会继承父scope的reuse标志。
    2. 自动复用(设置reuse为tf.AUTO_REUSE),如果变量存在则复用,不存在则创建。这是最安全的用法,在使用新推出的EagerMode时reuse将被强制为tf.AUTO_REUSE选项。用法如下:
      [plain] view plain copy
       
      1. def foo():  
      2.   with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):  
      3.     v = tf.get_variable("v", [1])  
      4.   return v  
      5.   
      6. v1 = foo()  # Creates v.  
      7. v2 = foo()  # Gets the same, existing v.  
      8. assert v1 == v2  
    3. 复用(设置reuse=True):
      [plain] view plain copy
       
      1. with tf.variable_scope("foo"):  
      2.   v = tf.get_variable("v", [1])  
      3. with tf.variable_scope("foo", reuse=True):  
      4.   v1 = tf.get_variable("v", [1])  
      5. assert v1 == v  
    4. 捕获某一域并设置复用(scope.reuse_variables()):
      [plain] view plain copy
       
      1. with tf.variable_scope("foo") as scope:  
      2.   v = tf.get_variable("v", [1])  
      3.   scope.reuse_variables()  
      4.   v1 = tf.get_variable("v", [1])  
      5. assert v1 == v  

      1)非复用的scope下再次定义已存在的变量;或2)定义了复用但无法找到已定义的变量,TensorFlow都会抛出错误,具体如下:
    [plain] view plain copy
     
    1. with tf.variable_scope("foo"):  
    2.     v = tf.get_variable("v", [1])  
    3.     v1 = tf.get_variable("v", [1])  
    4.     #  Raises ValueError("... v already exists ...").  
    5.   
    6.   
    7. with tf.variable_scope("foo", reuse=True):  
    8.     v = tf.get_variable("v", [1])  
    9.     #  Raises ValueError("... v does not exists ...").  
     
    转自: https://blog.csdn.net/zbgjhy88/article/details/78960388
  • 相关阅读:
    Eclipse的常见使用错误及编译错误
    Android学习笔记之Bundle
    Android牟利之道(二)广告平台的介绍
    Perl dbmopen()函数
    Perl子例程(函数)
    Perl内置操作符
    Perl正则表达式
    Linux之间配置SSH互信(SSH免密码登录)
    思科路由器NAT配置详解(转)
    Windows下查看端口被程序占用的方法
  • 原文地址:https://www.cnblogs.com/pzf9266/p/9012296.html
Copyright © 2011-2022 走看看