zoukankan      html  css  js  c++  java
  • tf.variable和tf.get_Variable以及tf.name_scope和tf.variable_scope的区别

    在训练深度网络时,为了减少需要训练参数的个数(比如具有simase结构的LSTM模型)、或是多机多卡并行化训练大数据大模型(比如数据并行化)等情况时,往往需要共享变量。另外一方面是当一个深度学习模型变得非常复杂的时候,往往存在大量的变量和操作,如何避免这些变量名和操作名的唯一不重复,同时维护一个条理清晰的graph非常重要。
    因此,tensorflow中用tf.Variable(),tf.get_variable(),tf.Variable_scope(),tf.name_scope()几个函数来实现:


    一、tf.Variable(<variable_name>),tf.get_variable(<variable_name>)的作用与区别:

    tf.Variable(<variable_name>)和tf.get_variable(<variable_name>)都是用于在一个name_scope下面获取或创建一个变量的两种方式,区别在于:

    1. tf.Variable(<variable_name>)会自动检测命名冲突并自行处理,但tf.get_variable(<variable_name>)则遇到重名的变量创建且变量名没有设置为共享变量时,则会报错。
    2. tf.Variable(<variable_name>)用于创建一个新变量,在同一个name_scope下面,可以创建相同名字的变量,底层实现会自动引入别名机制,两次调用产生了其实是两个不同的变量。
      tf.get_variable(<variable_name>)用于获取一个变量,并且不受name_scope的约束。当这个变量已经存在时,则自动获取;如果不存在,则自动创建一个变量。
    二、tf.name_scope(<scope_name>)与tf.variable_scope(<scope_name>)的作用与区别:

    tf.name_scope(<scope_name>):主要用于管理一个图里面的各种op,返回的是一个以scope_name命名的context manager。一个graph会维护一个name_space的
    堆,每一个namespace下面可以定义各种op或者子namespace,实现一种层次化有条理的管理,避免各个op之间命名冲突。

    tf.variable_scope(<scope_name>):一般与tf.name_scope()配合使用,用于管理一个graph中变量的名字,避免变量之间的命名冲突,tf.variable_scope(<scope_name>)允许在一个variable_scope下面共享变量。

    代码示例:

    在 tf.name_scope下时,tf.get_variable()创建的变量名不受 name_scope 的影响,而且在未指定共享变量时,如果重名会报错,tf.Variable()会自动检测有没有变量重名,如果有则会自行处理。

    import tensorflow as tf
    
    with tf.name_scope('name_scope_x'):
        var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
        var3 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
        var4 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(var1.name, sess.run(var1))
        print(var3.name, sess.run(var3))
        print(var4.name, sess.run(var4))
    # 输出结果:
    # var1:0 [-0.30036557]   可以看到前面不含有指定的'name_scope_x'
    # name_scope_x/var2:0 [ 2.]
    # name_scope_x/var2_1:0 [ 2.]  可以看到变量名自行变成了'var2_1',避免了和'var2'冲突
    

    如果使用tf.get_variable()创建变量,且没有设置共享变量,重名时会报错

    import tensorflow as tf
    
    with tf.name_scope('name_scope_1'):
        var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
        var2 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(var1.name, sess.run(var1))
        print(var2.name, sess.run(var2))
    
    # ValueError: Variable var1 already exists, disallowed. Did you mean 
    # to set reuse=True in VarScope? Originally defined at:
    # var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
    

    所以要共享变量,需要使用tf.variable_scope()

    import tensorflow as tf
    
    with tf.variable_scope('variable_scope_y') as scope:
        var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
        scope.reuse_variables()  # 设置共享变量
        var1_reuse = tf.get_variable(name='var1')
        var2 = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)
        var2_reuse = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(var1.name, sess.run(var1))
        print(var1_reuse.name, sess.run(var1_reuse))
        print(var2.name, sess.run(var2))
        print(var2_reuse.name, sess.run(var2_reuse))
    # 输出结果:
    # variable_scope_y/var1:0 [-1.59682846]
    # variable_scope_y/var1:0 [-1.59682846]   可以看到变量var1_reuse重复使用了var1
    # variable_scope_y/var2:0 [ 2.]
    # variable_scope_y/var2_1:0 [ 2.]
    
  • 相关阅读:
    ESB企业服务总线
    OpenStack的架构详解[精51cto]
    用MSBuild和Jenkins搭建持续集成环境(1)[收集]
    Hmac算法
    自定义JDBCUtils工具类
    读取JDBC配置文件的二种方式
    哈希算法
    BouncyCastle
    签名算法
    3种查看java字节码的方式
  • 原文地址:https://www.cnblogs.com/guoyaohua/p/8081192.html
Copyright © 2011-2022 走看看