zoukankan      html  css  js  c++  java
  • Tensorflow中的变量

    从初识tf开始,变量这个名词就一直都很重要,因为深度模型往往所要获得的就是通过参数和函数对某一或某些具体事物的抽象表达。而那些未知的数据需要通过学习而获得,在学习的过程中它们不断变化着,最终收敛达到较好的表达能力,因此它们无疑是变量。

    正如三位大牛所言:深度学习是一种多层表示学习方法,用简单的非线性模块构建而成,这些模块将上一层表示转化成更高层、更抽象的表示。

    原文如下: Deep-learning methods are representation-learning methods with multiple levels of representation, obtained by composing simple but non-linear modules that each transform the representation at one level (starting with the raw input) into a representation at a higher, slightly more abstract level.

    必读文献之一:Deep Learning

    当训练模型时,用变量来存储和更新参数。变量包含张量 (Tensor)存放于内存的缓存区。建模时它们需要被明确地初始化,模型训练后它们必须被存储到磁盘。这些变量的值可在之后模型训练和分析是被加载。

    通过之前的学习,可以例举出以下tf的函数:

    var = tf.get_variable(name, shape, initializer=initializer)
    global_step = tf.Variable(0, trainable=False)
    init = tf.initialize_all_variables()#高版本tf已经舍弃该函数,改用global_variables_initializer()
    saver = tf.train.Saver(tf.global_variables())
    initial = tf.constant(0.1, shape=shape)
    initial = tf.truncated_normal(shape, stddev=0.1)
    tf.global_variables_initializer()

    上述函数都和tf的参数有关,主要包含在以下两类中:

    从变量存在的整个过程来看上述两类:变量的创建、初始化、更新、保存和加载。

    •  创建

    当创建一个变量时,将一个张量作为初始值传入构造函数Variable()。tf提供了一系列操作符来初始化张量,初始值是常量或是随机值。注意,所有这些操作符都需要你指定张量的shape。变量的shape通常是固定的,但TensorFlow提供了高级的机制来重新调整其行列数。

    可以创建以下类型的变量:常数、序列、随机数。例如:

    #-*-coding:utf-8-*-
    #创建常数变量的例子
    import tensorflow as tf
    #常数constant
    tensor=tf.constant([[1,3,5],[8,0,7]])
    #创建tensor值为0的变量
    x = tf.zeros([3,4])
    #创建tensor值为1的变量
    x1 = tf.ones([3,4])
    #创建shape和tensor一样的但是值全为0的变量
    y = tf.zeros_like(tensor)
    #创建shape和tensor一样的但是值全为1的变量
    y1 = tf.ones_like(tensor)
    #用8填充shape为2*3的tensor变量
    z = tf.fill([2,3],8)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    print (sess.run(x))  
    print (sess.run(y))  
    print (sess.run(tensor))
    print (sess.run(x1))
    print (sess.run(y1))
    print (sess.run(z))
    #-*-coding:utf-8-*-
    #创建数字序列变量的例子
    import tensorflow as tf
    
    x=tf.linspace(10.0, 15.0, 3, name="linspace")
    y=tf.lin_space(10.0, 15.0, 3)
    w=tf.range(8.0, 13.0, 2.0)
    z=tf.range(3, -3, -2)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    
    print (sess.run(x))  
    print (sess.run(y))
    print (sess.run(w))
    print (sess.run(z))

     随机常量的创建详见tensorflow随机张量创建

    #创建随机变量的例子
    weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                          name="weights")
    • 初始化

    变量的初始化必须在模型的其它操作运行之前先明确地完成。最简单的方法就是添加一个给所有变量初始化的操作,并在使用模型之前首先运行那个操作。使用tf.global_variables_initializer()添加一个操作对变量做初始化。例如:

    # Create two variables.
    weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                          name="weights")
    biases = tf.Variable(tf.zeros([200]), name="biases")
    ...
    # Add an op to initialize the variables.
    init = tf.global_variables_initializer()
    
    # Later, when launching the model
    with tf.Session() as sess:
      # Run the init operation.
      sess.run(init)
      ...
      # Use the model
      ...

    有时候会需要用另一个变量的初始化值给当前变量初始化。由于tf.global_variables_initializer()是并行地初始化所有变量,所以用其它变量的值初始化一个新的变量时,使用其它变量的initialized_value()属性。你可以直接把已初始化的值作为新变量的初始值,或者把它当做tensor计算得到一个值赋予新变量。例如:

    # Create a variable with a random value.
    weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                          name="weights")
    # Create another variable with the same value as 'weights'.
    w2 = tf.Variable(weights.initialized_value(), name="w2")
    # Create another variable with twice the value of 'weights'
    w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")

    assign()函数也有初始化的功能,详见assign()函数

    另外,这里还应该说明的是还有三种读取数据的方法:Feeding、文件中读取、加载预训练数据,它们都属于给变量初始化的方式。为了不至于引起混淆,必须说明的是常量也是变量,而三种读取数据的方法,都是读取常量的方法,但依然是初始化的一种常见方式。详见Tensorflow数据读取的方式

    •  更新

    虽然assign()函数有对变量进行更新的作用,但是此处探讨的更新却不是如此简单。而事实上,我们不需要做什么具体的事情,因为tf是自动求导求梯度,根据代价函数自动更新参数的。这是全局参数的更新,也是tf学习的机制自动确定的。那tf如何知道哪个究竟是变量,哪个究竟又是常量呢?很简单,tf.variable()里面有个布尔型的参数trainable,表示这个参数是不是需要学习的变量,而它默认为true,因此很容易被忽略,就这样tf图会把它加入到GraphKeys.TRAINABLE_VARIABLES,从而对其进行更新。 

    • 保存

    对于训练的变量,成功的话,都是有意义的,需要将其保存在文件里,方便以后的测试和再训练,这就是weights文件,是必不可少的。

    在cifar10项目中当然也有保存这些变量,例如:

    # Create a saver.
        saver = tf.train.Saver(tf.global_variables())
    ......
    # Save the model checkpoint periodically.
        if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
            checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=step)

    Saver类把变量存储在二进制文件checkpoint里,主要包含从变量名到tensor值的映射关系。

    • 加载

    加载变量和保存变量是正反的过程,保存变量是要把模型里的变量信息保存到weights文件里,而加载变量就是要把这些有意义的变量值从weights文件加载到模型里。

    同理在cifar10项目中测试训练的模型时加载了上述保存的变量,例如:

      with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
          # Restores from checkpoint
          saver.restore(sess, ckpt.model_checkpoint_path)
          # Assuming model_checkpoint_path looks something like:
          #   /my-favorite-path/cifar10_train/model.ckpt-0,
          # extract global_step from it.
          global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
          print('No checkpoint file found')
          return

    如果想选择和加载某一部分变量,则可以通过变量名索引,例如:

    # Create some variables.
    v1 = tf.Variable(..., name="v1")
    v2 = tf.Variable(..., name="v2")
    ...
    # Add ops to save and restore only 'v2' using the name "my_v2"
    saver = tf.train.Saver({"my_v2": v2})
    # Use the saver object normally after that.
    ...

    这里my_v2就是新的变量名,而v2就是它的值。

  • 相关阅读:
    优先队列
    Problem W UVA 662 二十三 Fast Food
    UVA 607 二十二 Scheduling Lectures
    UVA 590 二十一 Always on the run
    UVA 442 二十 Matrix Chain Multiplication
    UVA 437 十九 The Tower of Babylon
    UVA 10254 十八 The Priest Mathematician
    UVA 10453 十七 Make Palindrome
    UVA 10163 十六 Storage Keepers
    UVA 1252 十五 Twenty Questions
  • 原文地址:https://www.cnblogs.com/cvtoEyes/p/8998341.html
Copyright © 2011-2022 走看看