zoukankan      html  css  js  c++  java
  • PyTorch grad_fn的作用以及RepeatBackward, SliceBackward示例

    变量.grad_fn表明该变量是怎么来的,用于指导反向传播。例如loss = a+b,则loss.gard_fn为<AddBackward0 at 0x7f2c90393748>,表明loss是由相加得来的,这个grad_fn可指导怎么求a和b的导数

    程序示例:

    import torch
    
    w1 = torch.tensor(2.0, requires_grad=True)
    a = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
    tmp = a[0, :]
    tmp.retain_grad()   # tmp是非叶子张量,需用.retain_grad()方法保留导数,否则导数将会在反向传播完成之后被释放掉
    b = tmp.repeat([3, 1])
    b.retain_grad()
    loss = (b * w1).mean()
    loss.backward()
    
    print(b.grad_fn)    # 输出: <RepeatBackward object at 0x7f2c903a10f0>
    print(b.grad)       # 输出: tensor([[0.3333, 0.3333],
                        #               [0.3333, 0.3333],
                        #               [0.3333, 0.3333]])
    
    print(tmp.grad_fn)    # 输出:<SliceBackward object at 0x7f2c90393f60>
    print(tmp.grad)       # 输出:tensor([1., 1.])
    
    
    print(a.grad)     # 输出:tensor([[1., 1.],
                      #              [0., 0.]])

    手动推导:

    手动推导的结果和程序的结果是一致的。

  • 相关阅读:
    windows下安装mysql教程
    git基本操作
    JDK8,Optional
    重新安装MySQL 8出现的问题
    HTML5学习:缩略图
    HTML5学习:表格
    MySQL常用命令
    Django学习:创建admin后台管理站点
    Django学习:连接Mysql数据库
    Django学习:创建第一个app
  • 原文地址:https://www.cnblogs.com/picassooo/p/13757403.html
Copyright © 2011-2022 走看看