zoukankan      html  css  js  c++  java
  • pytorch(3)----基本类型Autograd计算图

    pytorch 0.4版本之后,torch.autograd.Variable 和torch.Tensor 进行了整合。

    基本内容

    1、 创建可以自动求导的Tensor,默认为false。requires_grad属性

    2、 Tensor 的连个属性:

          grad: 记录该Tensor对应的梯度

          grad_fn: 指向function 对象,记录了该Tensor的操作

    3、计算图,根节点、叶子节点、中间节点;判断是否是叶子节点(.is_leaf)

    4、计算叶节点梯度,根节点使用 .backward()函数

    代码示例

    1、

    # requires_grad属性,设置为可求导
    a = torch.randn(2,2,requires_grad=True)  #默认为false
    b = torch.randn(2,2)
    
    print(a.requires_grad,b.requires_grad)
    # 使用 .requires_grad_() 将requires_grad 设置为true
    # 等价与 b.requires_grad=True
    b.requires_grad_()
    print(b.requires_grad)

    2、

    #Tensor 的两个属性
    # grad: 记录该Tensor对应的梯度
    # grad_fn: 指向function 对象,记录了该Tensor进行过的操作
    
    print(a)
    print(b)
    c = a+b
    print(c)
    print(c.requires_grad)
    
    print(a.grad_fn)
    print(b.grad_fn)
    print(c.grad_fn)
    
    d = c.detach() #.detach() 获取数据,类似于.data(),但前者更安全,后者不会修改autograd的追踪信息
    print(d)
    print(d.requires_grad)

    3、

    # 计算图
    x = torch.randn(1)
    w = torch.ones(1,requires_grad=True)
    b = torch.ones(1,requires_grad=True)
    
    print(x.is_leaf,w.is_leaf,b.is_leaf)
    
    y = w*x
    z = y+b
    print(y.is_leaf,z.is_leaf)
    print(y.grad_fn,z.grad_fn)
    
    # 对根节点使用 .backward()可以获得 叶子节点的 梯度
    z.backward(retain_graph=True)
    print(w.grad)
    print(b.grad)
  • 相关阅读:
    hive按月/周统计
    mysql按周/月/年统计数据
    Linux命令-查看目录下文件个数
    hive终端常用指令
    Sql 对varchar格式进行时间排序
    Python学习笔记--2.3 list列表操作(切片)
    Python学习笔记--2.2 list列表练习
    Python学习笔记--2.1 list列表操作(增删改查)
    Python学习笔记--1 基础&一个登陆小程序
    接口测试基础知识
  • 原文地址:https://www.cnblogs.com/feihu-h/p/12305683.html
Copyright © 2011-2022 走看看