  • Some notes on gradients

    Each Variable has a .grad_fn attribute that references the Function that created it. The exception is Variables created directly by the user: their grad_fn is None.
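Here is a minimal sketch of the leaf/non-leaf distinction. It uses the current API: since PyTorch 0.4, Variable has been merged into torch.Tensor, so requires_grad is passed straight to the factory function. The real attributes is_leaf and grad_fn are shown, along with retain_grad(). That method asks autograd to keep .grad on a non-leaf tensor such as y, which is why y.grad stays None in part 1 below.

```python
import torch

# Leaf tensor created by the user: autograd records no creating Function.
x = torch.ones(2, 2, requires_grad=True)

# Non-leaf tensors produced by operations: grad_fn points at their backward node.
y = x + 2
z = y * y * 3
out = z.mean()

print(x.is_leaf, x.grad_fn)   # True None
print(y.is_leaf)              # False
print(y.grad_fn)              # a backward-node object (e.g. AddBackward0 in recent versions)

# By default .grad is only kept for leaves; retain_grad() opts a non-leaf in.
y.retain_grad()
out.backward()

print(x.grad)  # d(out)/dx = 6 * (x + 2) / 4 = 4.5 in every entry
print(y.grad)  # d(out)/dy = 6 * y / 4 = 4.5 in every entry
```

Without the retain_grad() call, y.grad would stay None after backward(), exactly as in the script below.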

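    For a scalar output, out.backward() is equivalent to out.backward(torch.Tensor([1.0])): the seed gradient d(out)/d(out) = 1 is created implicitly. A non-scalar output has no implicit seed and needs an explicit gradient argument (see part 2). In PyTorch 0.4 and later, where mean() returns a 0-dimensional tensor, the matching explicit seed is torch.tensor(1.0).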
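The following small check of that equivalence assumes PyTorch 0.4 or later, where the scalar seed is written torch.tensor(1.0). The helper name grad_of_out is hypothetical and exists only for this comparison.

```python
import torch

def grad_of_out(explicit_seed: bool) -> torch.Tensor:
    """Hypothetical helper: build out = mean(3 * (x + 2)**2) and return d(out)/dx."""
    x = torch.ones(2, 2, requires_grad=True)
    out = ((x + 2) * (x + 2) * 3).mean()   # 0-dimensional (scalar) tensor
    if explicit_seed:
        out.backward(torch.tensor(1.0))    # seed d(out)/d(out) = 1 passed explicitly
    else:
        out.backward()                     # same seed created implicitly
    return x.grad

g_implicit = grad_of_out(explicit_seed=False)
g_explicit = grad_of_out(explicit_seed=True)
print(torch.equal(g_implicit, g_explicit))  # True
```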

    import torch
    from torch.autograd import Variable

    # part 1
    x = Variable(torch.ones(2, 2), requires_grad=True)
    y = x + 2
    print(y.grad_fn)    # <torch.autograd.function.AddConstantBackward object at 0x000001D8A156E048> (newer versions print AddBackward0)
    print(y.grad)       # None
    z = y * y * 3
    out = z.mean()
    out.backward()
    print(out.grad)     # None: out is not a leaf, so its gradient is not kept
    print(y.grad)       # None: same reason, y is not a leaf
    print(x.grad)       # d(out)/dx
    '''
    Variable containing:
     4.5000  4.5000
     4.5000  4.5000
    [torch.FloatTensor of size 2x2]
    '''
    print(x.grad_fn)    # None: x was created by the user
    # print(x.grad_output)    # AttributeError: 'Variable' object has no attribute 'grad_output'

    # part 2
    x = torch.randn(3)
    x = Variable(x, requires_grad=True)
    y = x * 2
    # print(type(y))          # <class 'torch.autograd.variable.Variable'>
    # print(type(y.data))     # <class 'torch.FloatTensor'>
    # print(y.data.norm())    # 4.076032686856067
    while y.data.norm() < 1000:
        y = y * 2

    # print(x.grad)           # None
    gradients = torch.FloatTensor([0.1, 1.0, 0.0001])
    # print(y)  # Variable containing: 377.3516 997.8206 11.2558 [torch.FloatTensor of size 3]
    y.backward(gradients)
    # y.backward()            # would mean y.backward(torch.Tensor([1.0])), which fails because y has 3 elements:
    # RuntimeError: grad can be implicitly created only for scalar outputs

    print(x.grad)             # gradients multiplied by the total scaling factor y / x
    # print(x.grad_fn)          # None
    # print(x.grad_output)      # AttributeError: 'Variable' object has no attribute 'grad_output'
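Part 2 computes a vector-Jacobian product: for a non-scalar y, y.backward(v) accumulates v^T J into x.grad, where J = dy/dx. Here y = 2^k * x, so J is 2^k times the identity and x.grad should equal v * 2^k. The sketch below checks this under current PyTorch (plain tensors instead of Variable). The counter doublings and the fixed random seed are additions made only for this check.

```python
import torch

torch.manual_seed(0)                      # fixed seed so the run is repeatable
x = torch.randn(3, requires_grad=True)

# Same loop as part 2, but counting how many times x has been doubled.
y = x * 2
doublings = 1
while y.detach().norm() < 1000:
    y = y * 2
    doublings += 1

v = torch.tensor([0.1, 1.0, 0.0001])      # the "gradients" vector from part 2
y.backward(v)                             # accumulates v^T J into x.grad

# y = 2**doublings * x, so J = 2**doublings * I and v^T J = v * 2**doublings.
expected = v * 2 ** doublings
print(torch.allclose(x.grad, expected))   # True

# Without an explicit vector, backward() on a non-scalar output is rejected.
raised_error = False
try:
    (x * 2).backward()
except RuntimeError as err:
    raised_error = True
    print(err)  # grad can be implicitly created only for scalar outputs
```

Because every step multiplies by an exact power of two, the comparison holds exactly in float32, not just approximately.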
  • Original article: https://www.cnblogs.com/Joyce-song94/p/7481511.html