1. 让张量使用Variable类型,如下所示
1 from torch.autograd import Variable 2 3 inp = torch.zeros(2, 3) 4 inp = Variable(inp).type(torch.LongTensor) 5 print(inp)
Variable类型包装了Tensor类型,并提供了backward()接口
使用Variable类型的好处是,可以按照论文公式来直接使用,并在做张量运算之后,使用继承的backward()直接进行反向传播
2. 自定义类继承nn.Module
1 class CustomMSELoss(nn.Module): 2 def __init__(self): 3 super().__init__() 4 5 def forward(self, x, y): 6 return torch.mean(torch.pow((x - y), 2))
这种方法结构化程度高,在开发给用户使用时,由于不知道用户的Tensor是否是Variable类型,采用该方法可以减少问题。