zoukankan      html  css  js  c++  java
  • pytorch中的Variable() Learner

    函数简介

    torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现(tensor变成variable之后才能进行反向传播求梯度?用变量.backward()进行反向传播之后,var.grad中保存了var的梯度)

    x = Variable(tensor, requires_grad = True)

    Varibale包含三个属性:

    • data:存储了Tensor,是本体的数据
    • grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
    • grad_fn:指向Function对象,用于反向传播的梯度计算之用

    用法:

    import torch
    from torch.autograd import Variable
     
    x = Variable(torch.one(2,2), requires_grad = True)
    print(x)#其实查询的是x.data,是个tensor

    举个例子求梯度:

    构建一个简单的方程:y = x[0,0] + x[0,1] + x[1,0] + x[1,1],Variable的运算结果也是Variable,但是,中间结果反向传播中不会被求导()

    这和TensorFlow不太一致,TensorFlow中中间运算果数据结构均是Tensor

    y = x.sum()
     
    y
    """
      Variable containing:
       4
      [torch.FloatTensor of size 1]
    """
     
     
    #可以查看目标函数的.grad_fn方法,它用来求梯度
    y.grad_fn
    """
        <SumBackward0 at 0x18bcbfcdd30>
    """
     
    y.backward()  # 反向传播
    x.grad  # Variable的梯度保存在Variable.grad中
    """
      Variable containing:
       1  1
       1  1
      [torch.FloatTensor of size 2x2]
    """
     
     
    #grad属性保存在Variable中,新的梯度下来会进行累加,可以看到再次求导后结果变成了2,
    y.backward()
    x.grad  # 可以看到变量梯度是累加的
    """
        Variable containing:
         2  2
         2  2
        [torch.FloatTensor of size 2x2]
    """
     
    #所以要归零
    x.grad.data.zero_()  # 归零梯度,注意,在torch中所有的inplace操作都是要带下划线的,虽然就没有.data.zero()方法
     
    """
     0  0
     0  0
    [torch.FloatTensor of size 2x2]
    """
     
     
    #对比Variable和Tensor的接口,相差无两
    x = Variable(torch.ones(4, 5))
     
    y = torch.cos(x)                         # 传入Variable
    x_tensor_cos = torch.cos(x.data)  # 传入Tensor
     
    print(y)
    print(x_tensor_cos)
     
    """
    Variable containing:
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]
     
     
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
     0.5403  0.5403  0.5403  0.5403  0.5403
    [torch.FloatTensor of size 4x5]

    参考:

    https://blog.csdn.net/u012370185/article/details/94391428

  • 相关阅读:
    JavaScript学习(二)
    javaScript学习(一)
    CSS学习(一)
    HTML学习(一)
    ES之node机器配置elasticsearch.yml
    ES之master机器配置elasticsearch.yml
    jenkins--前端依赖之 node
    jenkins--邮件插件配置
    JsonPath提取表达式
    this关键字的作用
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15709599.html
Copyright © 2011-2022 走看看