zoukankan      html  css  js  c++  java
  • Tensorflow03-可训练变量与自动求导机制

    一、Variable 可训练变量

    • 对Tensor对象的进一步封装
    • 在模型训练过程中自动记录梯度信息,由算法自动优化
    • 可以被训练的变量
    • 在机器学习中作为模型参数

    1、创建 Variable 对象

    使用如下命令进行 variable 对象的创建。

    tf.Variable(initial_value, dtype)

    其中 initial_value 是传入的数值型数据,可以传入数字,python列表,ndarray对象 或者 tensor对象。dtype 为类型,默认是float32。

     2、Variable 对象的属性和方法

    trainable 属性:表示是否可以被训练。

    ResourceVariable。

    可训练变量的赋值

     

     使用函数 assign() assign_add() assign_sub()函数,分别对 variable 对象重新赋值,加 减 等操作。注:是返回一个操作后的对象,原变量不变。

    二、Tensorflow的自动求导 

    1、GradientTape 类

    tensorflow 提供了一个专门用于求导的类,GradientTape,可以形象的理解为记录数据梯度的磁带,通过它,可以实现对变量的自动求导和监视 。

    with tf.GradientTape() as tape:
        函数表达式
    grad = tape.gradient(函数, 自变量)

    GradientTape 类实现了上下文管理器,它能够监视 with语句块所有的变量和计算过程,并把他们自动记录在梯度带中。

    在上述代码中,tf.GradientTape() 为构造函数,首先使用它来创建梯度带对象 tape,同时tape 也是上下文管理器对象。然后把函数表达式写在 with 语句块中,监视要求导的变量。最后使用tape对象的gradient函数求得导数。gradient函数的第一个参数是被求导的函数,第二次参数是被求导的自变量。

     例如,求 y = x2x=3,即 x = 3 的导数的值,具体步骤如下:

    2、GradientTape 参数

     GradientTape 有两个参数,GradientTape(persistent, watch_accessed_variables),这两个参数都是 布尔类型。

    (1)persistent 参数

    第一个参数默认为 False,表示这个tape默认只能使用一次,求导之后就被销毁。

    如下,要求两个函数的导数的值,需要指定 persistent为True,并且在使用完后,需要手动销毁 tape,del tape

     

     (2)watch_accessed_variables 参数

    第二个参数表示自动监视所有的可训练变量,也就是 Variable 对象,默认为 True。如果设置为 False,就无法自动监视,如下图。

     在这种情况下,可以手动添加对变量的监视,使用的是 watch() 函数。

     GradientTape 类默认自动监视所有的可训练变量,使用watch函数还可以监视非可训练对象,比如,可以讲上图的 x 由 Variable对象换为 Tensor对象,仍然可以得到正确结果。

    3、多元函数求偏导数

    函数 tape.gradient() 的第二个参数自变量,可以是一个,可以是多个。如果对多个自变量求偏导数时,只要把所有的自变量都放到一个列表里就可以了。如下图。

     也可将画线一行代码写成如下格式:

    df_dx = tape.gradient(f,x)
    df_dy = tape.gradient(f,y)

    但最后需要使用 del tape 进行资源释放。

    4、求二阶导数

     5、对向量求偏导

  • 相关阅读:
    算法笔记--中国剩余定理
    算法笔记--辛普森积分公式
    算法笔记--数学之不定方程解的个数
    算法笔记--卢卡斯定理
    洛谷 P3808 【模板】AC自动机(简单版)洛谷 P3796 【模板】AC自动机(加强版)
    hihocoder #1419 : 后缀数组四·重复旋律4
    codevs 3044 矩形面积求并 || hdu 1542
    Stamps ans Envelope Sive UVA
    洛谷 P2061 [USACO07OPEN]城市的地平线City Horizon
    bzoj 3277: 串
  • 原文地址:https://www.cnblogs.com/dongao/p/14365269.html
Copyright © 2011-2022 走看看