zoukankan      html  css  js  c++  java
  • tensorflow基础【2】-Variable 详解

    Variable 的主要作用是维护特定节点的状态,如深度学习模型参数

    创建_基础操作

    创建 Variable 有两种方式

    tf.Variable 

    创建唯一变量

    class VariableV1(Variable):
        def __init__(self,  
               initial_value=None,        # -- 变量值
               trainable=None,            # 该变量是否需要训练,或者说是否能被优化器更新
               collections=None,
               validate_shape=True,
               caching_device=None,
               name=None,                 # -- 变量的名字
               variable_def=None,
               dtype=None,                # -- 变量类型
               expected_shape=None,
               import_scope=None,
               constraint=None,
               use_resource=None,
               synchronization=VariableSynchronization.AUTO,
               aggregation=VariableAggregation.NONE,
               shape=None):               # -- 变量尺寸
            pass

    tf.Variable 是一个操作 (op),返回值是 Variable;

    d1 = tf.Variable(2)
    d2 = tf.Variable(3, dtype=tf.int32, name='int')
    d3 = tf.Variable(4., dtype=tf.float32, name='float')
    d4 = tf.add(d1, d2)
    d5 = d1 + d2
    # d6 = tf.add(d1, d3)     ### 不同类型的数据不能运算
    
    init = tf.global_variables_initializer()        ### 变量必须初始化
    
    sess1 = tf.Session()
    sess1.run(init)
    print(sess1.run(d4))        # 5
    print(sess1.run(d5))        # 5
    # print(sess1.run(d6))
    print(type(d5))             # <class 'tensorflow.python.framework.ops.Tensor'>

    tf.get_variable 

    获取或者创建共享变量:获取指定属性(如name)的现有变量,如果该变量不存在,就新建一个变量;

    d1 = tf.get_variable('d1', shape=[2, 3], initializer=tf.ones_initializer)
    d2 = tf.get_variable('d2', shape=[3, 2], initializer=tf.zeros_initializer)
    sess3 = tf.Session()
    sess3.run(tf.global_variables_initializer())
    print(sess3.run(d1))
    # [[1. 1. 1.]
    #  [1. 1. 1.]]
    print(sess3.run(d2))

    tf.Variable VS tf.get_variable

    tf.Variable 保证了变量的唯一性:当它检测到有命名冲突时,会自动处理冲突;

    tf.get_variable 用于共享变量:当它检测到有命名冲突时,会报错;

    ### Variable
    w_1 = tf.Variable(3, name="w_1")
    w_2 = tf.Variable(1, name="w_1")
    print(w_1.name)     # w_1:0
    print(w_2.name)     # w_1_1:0   系统检测到命名冲突,会自动处理,把w_1变成w_1_1,保证了变量的唯一性
    
    ### get_variable
    w_1 = tf.get_variable(name="g_1", initializer=1)
    # w_2 = tf.get_variable(name="w_1",initializer=2)     # 这句报错,系统检测到命名冲突,会直接报错
    # 错误信息
    # ValueError: Variable w_1 already exists, disallowed. Did you mean to set reuse=True in VarScope?

    二者混用比较复杂,具体用法参考我的博客 变量与作用域

    初始化

    Variable 在参与计算之前必须初始化, 两种方式

    d1 = tf.Variable(1)
    print(d1)       # <tf.Variable 'Variable:0' shape=() dtype=int32_ref>
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        print(sess.run(d1))     # 1
    
    ###
    d2 = tf.Variable(1)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()

    如果初始化后又创造了新变量,需要重新初始化

    上面是初始化全部变量,还有以下方法

    tf.local_variables_initializer()    ### 初始化局部变量
    tf.initialize_variables(var_list=, name=)   ### 初始化指定变量

    属性方法

    Variable 所有属性方法如下

    ['SaveSliceInfo', 'aggregation', 'assign', 'assign_add', 'assign_sub', 'batch_scatter_update', 'constraint', 'count_up_to', 'device', 'dtype', 'eval', 'from_proto', 'gather_nd', 'get_shape', 
    'graph', 'initial_value', 'initialized_value', 'initializer', 'load', 'name', 'op', 'read_value', 'scatter_add', 'scatter_nd_add', 'scatter_nd_sub', 'scatter_nd_update', 'scatter_sub', 
    'scatter_update', 'set_shape', 'shape', 'sparse_read', 'synchronization', 'to_proto', 'trainable', 'value']

    变量名

    节点名,也是变量名,如果创建 Variable 时显式的设置了 name,则取该 name,如果没有,则以 Variable_1 格式递增下标

    d1 = tf.Variable(tf.zeros(2,2))
    d2 = tf.Variable(2., dtype=tf.float32, name='d2')
    d3 = tf.Variable(3)
    print(d1)               # <tf.Variable 'Variable:0' shape=() dtype=float32_ref>
    print(d1.op.name)       # Variable_1
    print(d2.op.name)       # d2
    print(d3.op.name)       # Variable_2

    内存机制

    tf.Variable 创建的变量与张量一样,可以作为操作的输入和输出,不同之处在于:

    1. 张量的生命周期通常依赖计算的完成而结束,内存随即释放

    2. 变量常驻内存,随计算同步更新,不随计算结束而结束

    d1 = tf.Variable(2.)
    d2 = tf.constant(42.)
    print(d2)       # Tensor("Const:0", shape=(), dtype=float32)
    d3 = tf.assign_add(d1, 1.)
    d4 = tf.add(d2, 1.)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(2):
            print(sess.run(d3))     # 3.0,  4.0     ### 循环中 variable 状态持续更新,内存不释放
            print(sess.run(d4))     # 43.0, 43.0    ### 循环中 constant 状态不更新,内存实时释放
        print(d1, sess.run(d1))     # <tf.Variable 'Variable:0' shape=() dtype=float32_ref> 4.0     ### 循环结束之后 variable 状态保持,内存未释放
        print(d2, sess.run(d2))     # Tensor("Const:0", shape=(), dtype=float32) 42.0               ### 循环结束之后 constant 状态恢复,内存释放

    Variable 的这个特性常用于模型参数的迭代

    Variable 赋值

    变量的赋值不能直接用 =,有几种方式:tf.assign、var.assign、tf.assign_add

    d1 = tf.Variable(2)
    d2 = tf.Variable(3, dtype=tf.int32, name='int')
    d3 = tf.Variable(4., dtype=tf.float32, name='float')
    ## method1
    # d4 = tf.assign(d2, d3)      ### 两个变量数据类型要一致
    d5 = tf.assign(d2, d1)
    # d7 = tf.assign(d6, d3)      ### 被赋值的变量必须事先存在
    ## method2
    d8 = d2.assign(100)
    ## method3:加个数并赋值
    d9 = tf.assign_add(d2, 50)
    
    with tf.Session() as sess2:
        tf.global_variables_initializer().run()
        print(sess2.run(d5))        # 2     d5 被赋值了,等于 d2
        print(sess2.run(d2))        # 2     真正的 d2 也变了
        print(sess2.run(d8))        # 100
        print(sess2.run(d9))        # 150
        print(sess2.run(tf.assign_add(d2, 3)))  # 153
        print(sess2.run(tf.assign(d2, 3)))  # 3
        print(sess2.run(tf.assign(d2, d1))) # 2

    trainable

    trainable 属性指定变量是否参与训练,或者说是否能被优化器更新,类似于 PyTorch 中的 requires_grad;

    False 代表不参与训练,默认为 True;

    trainable 为只读属性,只在创建 Variable 时生效,后期无法更改;

    在创建优化器 Optimizer 的 minimize 张量时,tf 会把所有可训练的 Variable 收集到 trainable_variables 中,此后增加或者删除 可训练的变量,trainable_variables 不会变化;

    x = tf.Variable(3.0, dtype=tf.float32, trainable=False)     ### x 的 trainable 为 F,不参与训练
    y = tf.Variable(13.0, dtype=tf.float32)         ### 参与训练
    train_op = tf.train.AdamOptimizer(0.01).minimize(tf.abs(y - x))
    with tf.Session()as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(5):
            _, xx, yy = sess.run([train_op, x, y])
            print('epoch', i, xx, yy)  # 观察 x 和 y 的变化
    
    # epoch 0 3.0 12.99
    # epoch 1 3.0 12.98
    # epoch 2 3.0 12.969999
    # epoch 3 3.0 12.959999
    # epoch 4 3.0 12.949999

    tf.trainable_variables and tf.all_variables

    tf.trainable_variables()   返回的是 所有需要训练的变量列表

    tf.all_variables()      返回的是 所有变量的列表

    v1 = tf.Variable(0, name='v1')
    v2 = tf.Variable(tf.constant(5, shape=[1], dtype=tf.float32), name='v2')
    global_step = tf.Variable(6, name='global_step', trainable=False)       # 声明不是训练变量
    
    for ele1 in tf.trainable_variables():
        print(ele1.name)
    # v1:0
    # v2:0
    
    for ele2 in tf.all_variables():
        print(ele2.name)
    # v1:0
    # v2:0
    # global_step:0

    Variable 的保存和加载

    保存和加载都需要创建 Saver 对象,然后调用 save 保存 和 restore 加载

    from tensorflow.core.protobuf import saver_pb2
    class Saver(object):
        def __init__(self,
                   var_list=None,
                   reshape=False,
                   sharded=False,
                   max_to_keep=5,
                   keep_checkpoint_every_n_hours=10000.0,
                   name=None,
                   restore_sequentially=False,
                   saver_def=None,
                   builder=None,
                   defer_build=False,
                   allow_empty=False,
                   write_version=saver_pb2.SaverDef.V2,
                   pad_step_number=False,
                   save_relative_paths=False,
                   filename=None):
            pass
        def save(self,
               sess,
               save_path,
               global_step=None,
               latest_filename=None,
               meta_graph_suffix="meta",
               write_meta_graph=True,
               write_state=True,
               strip_default_attrs=False,
               save_debug_info=False):
            pass
        def restore(self, sess, save_path):
            """Restores previously saved variables."""
            pass

     

    创建 Saver 对象

    var_list:被保存或者加载的 variable,该参数的取值有几种形式:

    1. 默认为 None,即针对所有 Variable
    2. list 格式,指定 variable,variable 的命名默认为 v1、v2...
    3. dict 格式,指定 name 和 variable

    save 方法

    save_path:指定存储路径,一般以 ckpt(checkpoint) 结尾

    global_step:指定全局阶段,实际上就是个标记,通常是用一个数字指定 variable 在哪个阶段保存的,这个数字位于 filename 之后,具体见例子

    restore 方法

    save_path:注意这个路径要看 save 时的 global_step

    d1 = tf.Variable(1.)
    d2 = tf.Variable(2., dtype=tf.float32, name='d2')
    init = tf.global_variables_initializer()
    
    ### 初始化 Saver 对象
    saver = tf.train.Saver()            ### 保存所有变量
    saver1 = tf.train.Saver([d1, d2])       ### list 指定变量,变量名默认为 v1 v2 递增
    saver2 = tf.train.Saver({'v1': d1, 'v2':d2})      ### dict 指定变量和变量名
    
    with tf.Session() as sess:
        sess.run(init)
        ### save 方法保存变量
        saver.save(sess, './var/all.ckpt')
        saver1.save(sess, './var/list.ckpt', global_step=0)
        print(saver2.save(sess, './var/dict.ckpt', global_step=1))      # ./var/dict.ckpt-1
        sess.run(tf.assign_add(d2, 3.))     ### 保存之后再改变
        print(sess.run(d2))     # 5.0
    
        ### 加载变量 1:同一个 saver、sess
        saver2.restore(sess, './var/dict.ckpt-1')
        print(sess.run(d2))     # 2.0       ### 加载的是改变前的值,说明保存成功
    
    ### 加载变量 2:同一个 saver,不同的 sess
    with tf.Session() as sess:
        saver2.restore(sess, './var/dict.ckpt-1')
        print(sess.run(d2))     # 2.0
    
    ### 加载变量 3:不同的 saver,不同的 sess
    saver2 = tf.train.Saver({'v2':d2})
    with tf.Session() as sess:
        saver2.restore(sess, './var/dict.ckpt-1')
        print(sess.run(d2))     # 2.0

    可见,保存与加载相互独立

    上述代码保存 variable 结果如下

    1. global_step 被加到 filename 之后

    2. save 会生成 4 个文件 data、index、meta、checkpoint

    • data:存放模型参数
    • meta:存放计算图
    • checkpoint:记录模型存储的路径,model_checkpoint_path 代表最新的模型存储路径,all_model_checkpoint_paths 代表所有模型的存储路径

    3. 最多只保存近 5 次的存储

    4. 多次保存只有一个 checkpoint

    参考资料:

    https://www.cnblogs.com/weiyinfu/p/9973022.html  tensorflow动态设置trainable

  • 相关阅读:
    JSP error: Only a type can be imported
    关于JAVA插入Mysql数据库中文乱码问题解决方案
    MySQL SQL优化——分片搜索
    myeclipse 调试JSP页面
    jsp:usebean 常用注意事项
    spring XML格式
    VB 要求对象
    VB 对象变量或with块变量未设置
    Spring依赖注入
    Spring 读取XML配置文件的两种方式
  • 原文地址:https://www.cnblogs.com/yanshw/p/12341295.html
Copyright © 2011-2022 走看看