zoukankan      html  css  js  c++  java
  • 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导

    torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现

    Varibale包含三个属性:

    • data:存储了Tensor,是本体的数据
    • grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
    • grad_fn:指向Function对象,用于反向传播的梯度计算之用

    data

    import torch as t
    from torch.autograd import Variable
    
    x = Variable(t.ones(2, 2), requires_grad = True)
    x  # 实际查询的是.data,是个Tensor
    

    实际上查询x和查询x.data返回结果一致,

    Variable containing:

     1 1

     1 1

     [torch.FloatTensor of size 2x2]

    梯度求解

    构建一个简单的方程:y = x[0,0] + x[0,1] + x[1,0] + x[1,1],Variable的运算结果也是Variable,但是,中间结果反向传播中不会被求导()

    这和TensorFlow不太一致,TensorFlow中中间运算果数据结构均是Tensor,

    y = x.sum()
    
    y
    """
      Variable containing:
       4
      [torch.FloatTensor of size 1]
    """
    

    可以查看目标函数的.grad_fn方法,它用来求梯度,

    y.grad_fn
    """
        <SumBackward0 at 0x18bcbfcdd30>
    """
    
    y.backward()  # 反向传播
    x.grad  # Variable的梯度保存在Variable.grad中
    """
      Variable containing:
       1  1
       1  1
      [torch.FloatTensor of size 2x2]
    """
    

    grad属性保存在Variable中,新的梯度下来会进行累加,可以看到再次求导后结果变成了2,

    y.backward()
    x.grad  # 可以看到变量梯度是累加的
    """
        Variable containing:
         2  2
         2  2
        [torch.FloatTensor of size 2x2]
    """
    

    所以要归零,

    x.grad.data.zero_()  # 归零梯度,注意,在torch中所有的inplace操作都是要带下划线的,虽然就没有.data.zero()方法
    
    """
     0  0
     0  0
    [torch.FloatTensor of size 2x2]
    """
    

    对比Variable和Tensor的接口,相差无两,

    Variable和Tensor的接口近乎一致,可以无缝切换
    
    x = Variable(t.ones(4, 5))
    
    y = t.cos(x)                         # 传入Variable
    x_tensor_cos = t.cos(x.data)  # 传入Tensor
    
    print(y)
    print(x_tensor_cos)
    
    """
    Variable containing:
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
    
    
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
    """
    
  • 相关阅读:
    C# 从服务器下载文件
    不能使用联机NuGet 程序包
    NPOI之Excel——合并单元格、设置样式、输入公式
    jquery hover事件中 fadeIn和fadeOut 效果不能及时停止
    UVA 10519 !! Really Strange !!
    UVA 10359 Tiling
    UVA 10940 Throwing cards away II
    UVA 10079 Pizze Cutting
    UVA 763 Fibinary Numbers
    UVA 10229 Modular Fibonacci
  • 原文地址:https://www.cnblogs.com/hellcat/p/8439055.html
Copyright © 2011-2022 走看看