zoukankan      html  css  js  c++  java
  • pytorch detach函数

    用于截断反向传播

    detach()源码:

    def detach(self):
        result = NoGrad()(self)  # this is needed, because it merges version counters
        result._grad_fn = None
        return result

    它的返回结果与调用者共享一个data tensor,且会将grad_fn设为None,这样就不知道该Tensor是由什么操作建立的,截断反向传播

    这个时候再一个tensor使用In_place操作会导致另一个的data tensor也会发生改变

    import torch
    
    a = torch.tensor([1, 2, 3.], requires_grad=True)
    out = a.sigmoid()
    print(out)#tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
    
    c = out.detach()
    print(c)#tensor([0.7311, 0.8808, 0.9526])

    这个时候可以看到,c和out的区别就是一个有grad_fn,一个没有grad_fn

    执行out.sum().backward()没有问题,但执行c.sum().backward()报错:

    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

    这个时候不论是对out还是对c进行inplace操作改变它们的data,这个改动会被autograd追踪,这个时候再执行out.sum().backward()会报错

    假设对out进行inplace操作,会出现:

    out.zero_()
    #tensor([0., 0., 0.], grad_fn=<ZeroBackward>)
    
    out.sum().backward()
    #报错

    错误信息为

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3]], which is output 0 of SigmoidBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

    如果不对out进行inplace操作而是对c进行inplace操作,结果是一样的,Out不能再进行反向传播了

    为了解决这种情况,就要对tensor的data操作,使其不被autograd记录
    重新得到一个out,把它的data部分给c

    c = out.data
    #tensor([0.7311, 0.8808, 0.9526])
    
    out
    #tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)

    这里可以看到,c中没有Out中有的grad_fn信息

    这回修改c的值,发现out的data值依然改了,但是执行out.sum().backward()不报错了

    detach_()

    def detach_(self):
        """Detaches the Variable from the graph that created it, making it a leaf.
        """
        self._grad_fn = None
        self.requires_grad = False

    做了两件事:1grad_fn设none2requires_grad设false

    它不会新生成一个Variable而是使用原来的variable

  • 相关阅读:
    Datax streamreader json测试样例
    dbeaver 连接 elasticsearch 记录
    灾害链开发记录资料汇总
    mxgraph
    drawio www.diagrams.net 画图应用程序开发过程资料汇总
    neo4j学习记录
    GraphVis 图可视化分析组件
    D3学习记录
    Kubernetes K8S之固定节点nodeName和nodeSelector调度详解
    记一次性能优化,单台4核8G机器支撑5万QPS
  • 原文地址:https://www.cnblogs.com/ljf-0/p/14015601.html
Copyright © 2011-2022 走看看