zoukankan      html  css  js  c++  java
  • Tensorflow中的name_scope和variable_scope

    Tensorflow是一个编程模型,几乎成为了一种编程语言(里面有变量、有操作......)。
    Tensorflow编程分为两个阶段:构图阶段+运行时。
    Tensorflow构图阶段其实就是在对图进行一些描述性语言,跟html很像,很适合用标记性语言来描述。
    Tensorflow是有向图,是一个有向无环图。张量为边,操作为点,数据在图中流动。
    Tensorflow为每个结点都起了唯一的一个名字。

    import tensorflow as tf
    
    a = tf.constant(3)  # name=Const:0
    b = tf.Variable(4)  # name=Variable:0
    print(a.name, b.name)
    

    如上所示,即便你没有指明变量的name属性,tensorflow也会给它起个默认名字。

    在C++中有namespace的概念,命名空间的好处就是减少了命名冲突,我们可以在命名空间中使用较简易的标识符。为了便于用户定义变量的name,tensorflow也提出了name_scope

    with tf.name_scope("my"):
        a = tf.constant(3)  # my/Const:0
        b = tf.add(a, b)  # my/Add:0
        print(a.name, b.name)
        # 使用get_variable却不管用
        c = tf.get_variable("c", shape=1, dtype=tf.int32, initializer=tf.constant_initializer(2))  # c:0
        print(c.name)
    

    如上所示,在name_scope中的属性,会用类似文件路径的方式来定义变量的name属性
    但是关于name_scope需要明白两点:

    • 使用tf.get_variable函数创建的变量不会受name_scope的影响
    • 只有新创建的变量,如tf.constant,tf.Variable等新建结点的操作才会受到name_scope的影响

    总而言之,name_scope作用比较单一,仅仅是为了更改变量的name属性,便于命名变量。而variable_scope作用就很丰富了,它不仅能够改变变量的name属性,还能够实现变量管理功能。

    with tf.variable_scope("ha"):
        a = tf.constant(2, name="myconstant")  # ha/myconstant:0
        b = tf.get_variable("b", shape=1, dtype=tf.int32, initializer=tf.constant_initializer(2))  # ha/b:0
        print(a.name, b.name)
    

    在改变变量name属性这方面,variable_scope和name_scope基本没有差别,唯一的区别就是get_variable不会受到name_scope的影响,却会受到variable_scope的影响。

    variable_scope更加重要的功能是实现变量管理,其中最突出的一点就是变量共享

    # 下面我们来验证一下变量共享机制
    def get_share_variable(reuse):
        with tf.variable_scope("share", reuse=reuse):
            a = tf.get_variable("a", shape=(1), dtype=tf.int32)
            return a
    
    
    one = get_share_variable(False)
    two = get_share_variable(True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run([one, two, tf.assign(one, [2])]))
    
    

    再上面的例子中,通过reuse属性可以控制variable_scope中的变量是否需要重新创建,如果指定reuse=false,则必然会执行新建变量的操作。
    上面代码one和two输出值都变成了2,这说明它俩引用的是同一个对象。
    使用variable_scope实现变量共享需要注意以下几点:

    • 重用之前必须保证已经创建过,否则报错
    • 使用时必须指明shape和dtype,否则无法复用已创建的变量

    变量共享机制非常重要。一个非常常用的场景就是:训练完成之后保存模型,加载模型之后整个图已经建立好了,这时就需要通过variable_scope机制复用已经建好的图,然后测试。

    为了验证以上两点,请看下例:

    复用未曾创建过的变量会报错

    with tf.variable_scope("ha", default_name="what", reuse=True):
        try:
            m = tf.get_variable("m")
        except Exception as ex:
            print(ex)  # Variable ha/m does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
    

    它建议我们使用reuse=tf.AUTO_REUSE,这个属性值的含义表示:如果变量存在则复用,不存在则创建之。

    以前,reuse属性的取值为True和False,没有tf.AUTO_REUSE。但是在tensorflow未来的版本中,有可能会把reuse的取值弄成枚举类型。这就略微有点小坑了,不知AUTO_REUSE的使用场景是否对得起这么不优雅的设计。

    • AUTO_REUSE=1
    • REUSE_FALSE = 2
    • REUSE_TRUE = 3

    下面验证第二点,必须指明变量的形状和类型才可以复用

    with tf.variable_scope("ha", default_name="what", reuse=tf.AUTO_REUSE):
        try:
            m = tf.get_variable("m")
            print(m.name)
        except Exception as ex:
            print(ex)  # ValueError: Shape of a new variable (ha/m) must be fully defined, but instead was <unknown>.  这个错误是在说:创建变量必须指明变量的类型
    

    variable_scope的主要作用是变量共享。
    说到变量共享,不得不说tf.get_variable()函数。tf.Variable()是构造函数,一定会创建新的变量。tf.get_variable则可以直接访问已经创建的变量(只能在variable_scope中且reuse=True或者AUTO_REUSE的情况下),也可以创建新的变量(在variable_scope内部或者外面都可以,若是内部,则必须reuse=FALSE或者AUTO_REUSE)。

    通过以上过程可以发现,get_variable跟variable_scope非常合得来。get_variable不一定非在variable_scope中使用,但是不在variable_scope中使用,它就只能用来创建变量而无法复用变量了,因为只有variable_scope才拥有reuse属性。

    # reuse变量的唯一方式就是使用name_scope
    x = tf.Variable(3, False, name='x')
    print(x.name, x.shape)  # x:0 ()
    y = tf.get_variable("x", shape=x.shape)
    print(y.name, y.shape)  # x_1:0 ()
    

    使用get_variable时也有一些微操作,比如指定trainable属性指明该变量是否可以训练。

    with tf.variable_scope("ha", default_name="what", reuse=tf.AUTO_REUSE):
        # 当variable_scope reuse变量时,依旧可以对变量进行一些微操作:设置trainable=False,表示这个节点不可训练
        # 当获取变量时,dtype类型必须对应正确
        a = tf.get_variable("b", trainable=False, dtype=tf.int32)
        print(tf.trainable_variables("ha"))
    
    

    相比name_scope功能的单一,tensorflow对variable_scope玩了很多花样:

    • tf.get_trainable_variables() 获取全部可训练的变量
    • tf.get_variable_scope()获取当前的变量作用域

    为了对比说明问题,下面用嵌套的方式来实现作用域。

    with tf.variable_scope("one"):
        """
        使用name_scope只会影响变量的名字,它要解决的问题是变量重名问题
        """
        with tf.name_scope("two"):
            x = tf.constant(3)  # one/two/Const:0
            print(x.name)
            # variable_scope.name是根目录的名字
            print(tf.get_variable_scope().name, tf.get_variable_scope().original_name_scope)  # 输出为:one    one/
        with tf.variable_scope("three"):
            x = tf.constant(3)  # one/three/Const:0
            print(x.name)
            print(tf.get_variable_scope().name, tf.get_variable_scope().original_name_scope)  # 输出为one/three  one/three/
    

    可见,name_scope对于tf.get_variable_scope()来说几乎是不可见的,但却会对变量的命名产生影响,但却仅仅对变量名产生影响。

    函数不会阻隔with作用域

    # 使用函数依旧不会影响作用域
    def ha():
        x = tf.Variable(3)  # ha_3/Variable:0
        print(x.name)
    
    
    with tf.variable_scope("ha"):  # 这个变量作用域已经定义过好几次了,它的实际名字变成了ha_3
        ha()
    
    

    Tensorflow中每个结点的name都不会重复,如果重复了怎么办?

    
    # 如果重复定义变量
    a = tf.constant(2, name='a')  # a:0
    b = tf.constant(2, name='a')  # a_1:0
    print(a.name, b.name)
    
    # 如果重复定义name_scope
    with tf.name_scope("my"):
        a = tf.constant(2)  # my_1/Const:0
        print(a.name)
    """
    可见tensorflow对于一切重名的东西都会在末尾加上下划线+数字
    """
    with tf.name_scope("my_4"):
        a = tf.constant(2)
        print(a.name)  # my_4/Const:0
    with tf.name_scope("my"):
        a = tf.constant(2)
        print(a.name)  # my_2/Const:0
    with tf.name_scope("my_4"):
        a = tf.constant(2)  # my_4_1/Const:0
        print(a.name)
    

    通过以上例子可以发现,tensorflow对于命名重复问题使用以下规则解决:
    1、要使用的name不存在,可以直接使用
    2、要使用的name已经存在,执行下列循环:

    i=1
    while 1:
        now_name="%s_%d"%(name,i)
        if exists(now_name):
           i+=1
        else:
            return now_name
    
  • 相关阅读:
    SQL Server 2005中的分区表(六):将已分区表转换成普通表
    关于SQL Server中分区表的文件与文件组的删除(转)
    MySQL修改root密码的几种方法
    Aptana 插件 for Eclipse 4.4
    IT励志与指导文章合集(链接)
    正则表达式(转)
    《疯狂原始人》温馨而搞笑片段截图
    指针函数与函数指针的区别(转)
    Linux内核@系统组成与内核配置编译
    2015年我国IT行业发展趋势分析(转)
  • 原文地址:https://www.cnblogs.com/weiyinfu/p/9571986.html
Copyright © 2011-2022 走看看