torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现(tensor变成variable之后才能进行反向传播求梯度?用变量.backward()进行反向传播之后,var.grad中保存了var的梯度)
x = Variable(tensor, requires_grad = True)
Varibale包含三个属性:
- data:存储了Tensor,是本体的数据
- grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
- grad_fn:指向Function对象,用于反向传播的梯度计算之用
用法:
-
import torch
-
from torch.autograd import Variable
-
x = Variable(torch.one(2,2), requires_grad = True)
-
print(x)#其实查询的是x.data,是个tenso