zoukankan      html  css  js  c++  java
  • Pytorch autograd,backward详解

    转自https://zhuanlan.zhihu.com/p/83172023

    平常都是无脑使用backward,每次看到别人的代码里使用诸如autograd.grad这种方法的时候就有点抵触,今天花了点时间了解了一下原理,写下笔记以供以后参考。以下笔记基于Pytorch1.0

    Tensor

    Pytorch中所有的计算其实都可以回归到Tensor上,所以有必要重新认识一下Tensor。如果我们需要计算某个Tensor的导数,那么我们需要设置其.requires_grad属性为True。为方便说明,在本文中对于这种我们自己定义的变量,我们称之为叶子节点(leaf nodes),而基于叶子节点得到的中间或最终变量则可称之为结果节点。例如下面例子中的x则是叶子节点,y则是结果节点。

    x = torch.rand(3, requires_grad=True)
    y = x**2
    z = x + x

    另外一个Tensor中通常会记录如下图中所示的属性:

    • data: 即存储的数据信息
    • requires_grad: 设置为True则表示该Tensor需要求导
    • grad: 该Tensor的梯度值,每次在计算backward时都需要将前一时刻的梯度归零,否则梯度值会一直累加,这个会在后面讲到。
    • grad_fn: 叶子节点通常为None,只有结果节点的grad_fn才有效,用于指示梯度函数是哪种类型。例如上面示例代码中的y.grad_fn=<PowBackward0 at 0x213550af048>, z.grad_fn=<AddBackward0 at 0x2135df11be0>
    • is_leaf: 用来指示该Tensor是否是叶子节点。

    *图片出处:[PyTorch Autograd]()*

    torch.autograd.backward

    有如下代码:

    x = torch.tensor(1.0, requires_grad=True)
    y = torch.tensor(2.0, requires_grad=True)
    z = x**2+y
    z.backward()
    print(z, x.grad, y.grad)
    
    >>> tensor(3., grad_fn=<AddBackward0>) tensor(2.) tensor(1.)

    可以z是一个标量,当调用它的backward方法后会根据链式法则自动计算出叶子节点的梯度值。

    但是如果遇到z是一个向量或者是一个矩阵的情况,这个时候又该怎么计算梯度呢?这种情况我们需要定义grad_tensor来计算矩阵的梯度。在介绍为什么使用之前我们先看一下源代码中backward的接口是如何定义的:

    torch.autograd.backward(
    		tensors, 
    		grad_tensors=None, 
    		retain_graph=None, 
    		create_graph=False, 
    		grad_variables=None)
    • tensor: 用于计算梯度的tensor。也就是说这两种方式是等价的:torch.autograd.backward(z) == z.backward()
    • grad_tensors: 在计算矩阵的梯度时会用到。他其实也是一个tensor,shape一般需要和前面的tensor保持一致。
    • retain_graph: 通常在调用一次backward后,pytorch会自动把计算图销毁,所以要想对某个变量重复调用backward,则需要将该参数设置为True
    • create_graph: 当设置为True的时候可以用来计算更高阶的梯度
    • grad_variables: 这个官方说法是grad_variables' is deprecated. Use 'grad_tensors' instead.也就是说这个参数后面版本中应该会丢弃,直接使用grad_tensors就好了。

    好了,参数大致作用都介绍了,下面我们看看pytorch为什么设计了grad_tensors这么一个参数,以及它有什么用呢?

    还是用代码做个示例

    x = torch.ones(2,requires_grad=True)
    z = x + 2
    z.backward()
    
    >>> ...
    RuntimeError: grad can be implicitly created only for scalar outputs

    当我们运行上面的代码的话会报错,报错信息为RuntimeError: grad can be implicitly created only for scalar outputs

    上面的报错信息意思是只有对标量输出它才会计算梯度,而求一个矩阵对另一矩阵的导数束手无策。

    [公式]

    那么我们只要想办法把矩阵转变成一个标量不就好了?比如我们可以对z求和,然后用求和得到的标量在对x求导,这样不会对结果有影响,例如:

     

    [公式]

    我们可以看到对z求和后再计算梯度没有报错,结果也与预期一样:

    x = torch.ones(2,requires_grad=True)
    z = x + 2
    z.sum().backward()
    print(x.grad)
    
    >>> tensor([1., 1.])

    我们再仔细想想,对z求和不就是等价于z点乘一个一样维度的全为1的矩阵吗?即 [公式] ,而这个I也就是我们需要传入的grad_tensors参数。(点乘只是相对于一维向量而言的,对于矩阵或更高为的张量,可以看做是对每一个维度做点乘)

    代码如下:

    x = torch.ones(2,requires_grad=True)
    z = x + 2
    z.backward(torch.ones_like(z)) # grad_tensors需要与输入tensor大小一致
    print(x.grad)
    
    >>> tensor([1., 1.])

    弄个再复杂一点的:

    x = torch.tensor([2., 1.], requires_grad=True).view(1, 2)
    y = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)
    
    z = torch.mm(x, y)
    print(f"z:{z}")
    z.backward(torch.Tensor([[1., 0]]), retain_graph=True)
    print(f"x.grad: {x.grad}")
    print(f"y.grad: {y.grad}")
    
    >>> z:tensor([[5., 8.]], grad_fn=<MmBackward>)
    x.grad: tensor([[1., 3.]])
    y.grad: tensor([[2., 0.],
            [1., 0.]])

    结果解释如下:

    总结:

    说了这么多,grad_tensors的作用其实可以简单地理解成在求梯度时的权重,因为可能不同值的梯度对结果影响程度不同,所以pytorch弄了个这种接口,而没有固定为全是1。引用自知乎上的一个评论:如果从最后一个节点(总loss)来backward,这种实现(torch.sum(y*w))的意义就具体化为 multiple loss term with difference weights 这种需求了吧。

    torch.autograd.grad

    torch.autograd.grad(
    		outputs, 
    		inputs, 
    		grad_outputs=None, 
    		retain_graph=None, 
    		create_graph=False, 
    		only_inputs=True, 
    		allow_unused=False)

    看了前面的内容后在看这个函数就很好理解了,各参数作用如下:

    • outputs: 结果节点,即被求导数
    • inputs: 叶子节点
    • grad_outputs: 类似于backward方法中的grad_tensors
    • retain_graph: 同上
    • create_graph: 同上
    • only_inputs: 默认为True, 如果为True, 则只会返回指定input的梯度值。 若为False,则会计算所有叶子节点的梯度,并且将计算得到的梯度累加到各自的.grad属性上去。
    • allow_unused: 默认为False, 即必须要指定input,如果没有指定的话则报错。

    参考

  • 相关阅读:
    CF1480C Searching Local Minimum
    如何根据IP地址查到主机名
    转贴:关于内部重定向(forward)和外部重定向(redirect)
    读懂vmstat
    Javascript在网页页面加载时的执行顺序
    安全测试学习笔记一
    Linux文件查找命令find,xargs详述
    mvn常用命令
    prototype.js 让你更深入的了解javascript的面向对象特性
    【转】Velocity研究学习文档
  • 原文地址:https://www.cnblogs.com/skydaddy/p/11596056.html
Copyright © 2011-2022 走看看