zoukankan      html  css  js  c++  java
  • 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导

    torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现

    Varibale包含三个属性:

    • data:存储了Tensor,是本体的数据
    • grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
    • grad_fn:指向Function对象,用于反向传播的梯度计算之用

    data

    import torch as t
    from torch.autograd import Variable
    
    x = Variable(t.ones(2, 2), requires_grad = True)
    x  # 实际查询的是.data,是个Tensor
    

    实际上查询x和查询x.data返回结果一致,

    Variable containing:

     1 1

     1 1

     [torch.FloatTensor of size 2x2]

    梯度求解

    构建一个简单的方程:y = x[0,0] + x[0,1] + x[1,0] + x[1,1],Variable的运算结果也是Variable,但是,中间结果反向传播中不会被求导()

    这和TensorFlow不太一致,TensorFlow中中间运算果数据结构均是Tensor,

    y = x.sum()
    
    y
    """
      Variable containing:
       4
      [torch.FloatTensor of size 1]
    """
    

    可以查看目标函数的.grad_fn方法,它用来求梯度,

    y.grad_fn
    """
        <SumBackward0 at 0x18bcbfcdd30>
    """
    
    y.backward()  # 反向传播
    x.grad  # Variable的梯度保存在Variable.grad中
    """
      Variable containing:
       1  1
       1  1
      [torch.FloatTensor of size 2x2]
    """
    

    grad属性保存在Variable中,新的梯度下来会进行累加,可以看到再次求导后结果变成了2,

    y.backward()
    x.grad  # 可以看到变量梯度是累加的
    """
        Variable containing:
         2  2
         2  2
        [torch.FloatTensor of size 2x2]
    """
    

    所以要归零,

    x.grad.data.zero_()  # 归零梯度,注意,在torch中所有的inplace操作都是要带下划线的,虽然就没有.data.zero()方法
    
    """
     0  0
     0  0
    [torch.FloatTensor of size 2x2]
    """
    

    对比Variable和Tensor的接口,相差无两,

    Variable和Tensor的接口近乎一致,可以无缝切换
    
    x = Variable(t.ones(4, 5))
    
    y = t.cos(x)                         # 传入Variable
    x_tensor_cos = t.cos(x.data)  # 传入Tensor
    
    print(y)
    print(x_tensor_cos)
    
    """
    Variable containing:
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
    
    
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
    """
    
  • 相关阅读:
    第51天 [js] 字符串相连有哪些方式?哪种最好?为什么?
    第52天 [js] 什么是事件委托?它有什么好处?能简单的写一个例子吗?
    np.ndarray与Eigen::Matrix之间的互相转换
    C++向assert加入错误信息
    CeiT:训练更快的多层特征抽取ViT
    CoAtNet: 90.88% Paperwithcode榜单第一,层层深入考虑模型设计
    正式启用Danube 官方站点
    go 编译报错 package embed is not in GOROOT (/usr/local/go/src/embed)
    cloudreve兼容acme.sh脚本
    Git的交叉编译
  • 原文地址:https://www.cnblogs.com/hellcat/p/8439055.html
Copyright © 2011-2022 走看看