zoukankan      html  css  js  c++  java
  • pytorch反向传播两次,梯度相加,retain_graph=True

    pytorch是动态图计算机制,也就是说,每次正向传播时,pytorch会搭建一个计算图,loss.backward()之后,这个计算图的缓存会被释放掉,下一次正向传播时,pytorch会重新搭建一个计算图,如此循环。

    在默认情况下,PyTorch每一次搭建的计算图只允许一次反向传播,如果要进行两次反向传播,则需要在第一次反向传播时设置retain_graph=True,即 loss.backwad(retain_graph=True) ,这样做可以保留动态计算图,在第二次反向传播时,将自动和第一次的梯度相加。

    示例:

    import torch
    
    input_ = torch.tensor([[1., 2.], [3., 4.]], requires_grad=False)
    w1 = torch.tensor(2.0, requires_grad=True)
    w2 = torch.tensor(3.0, requires_grad=True)
    
    l1 = input_ * w1
    l2 = l1 + w2
    loss1 = l2.mean()
    loss1.backward(retain_graph=True)
    
    print(w1.grad)  # 输出:tensor(2.5)
    print(w2.grad)  # 输出:tensor(1.)
    
    loss2 = l2.sum()
    loss2.backward()
    
    print(w1.grad)  # 输出:tensor(12.5)
    print(w2.grad)  # 输出:tensor(5.)

    示例中的梯度推导很简单,我在这篇博客里推了一下。从输出结果来看,程序确实是把两次的梯度加起来了。

    附注:如果网络要进行两次反向传播,却没有用retain_graph=True,则运行时会报错:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

  • 相关阅读:
    备份
    >> 与 > >
    为什么需要htons(), ntohl(), ntohs(),htons() 函数
    小技巧
    C++头文件
    宏定义中的#,##操作符和... and _ _VA_ARGS_ _与自定义调试信息的输出
    OpenCV摄像头简单程序
    [转]让Linux的tty界面支持中文
    opencv 2 computer vision application programming第四章翻译
    OpenCV条码(6)简单实现
  • 原文地址:https://www.cnblogs.com/picassooo/p/13818952.html
Copyright © 2011-2022 走看看