1.MSE(均方差)梯度
(1)均方差MSE
(2)MSE求梯度
【注】例如网络形式为线性感知机:ƒ(x)=w*x+b这里只是举例,具体用什么样的函数需要根据实际的网络结构。
对w求导则是:Δƒw(w)/Δw
对b求导则是:Δƒb(b)/Δb
(3)均方差在pytorch中如何求梯度
(3.1.1)torch.autograd.grad(loss,[w1,w2...........])
【注】pytorch中mse_loss的自动微分:
F.mse_loss(label,pred) pred的为线性感知机中的w*x+b,label为x。
torch.autograd.grad(mse,para)para为线性感知机中的w和b参数。其中第一个参数必须为维度为1长度为1的tensor。
【注】只有浮点数型数据才能计算梯度,故上图中会出现23和24行下面的错误。requires_grad_()可以对tensor类型的数据进行更新,使其可以进行梯度运算。
(3.1.2)loss.backward()
(3.1.3)pytorch中损失函数求梯度的两种方法总结
[注]两种方式返回值的形式不同:
第一种为【w1 grad,w2 grad】
第二种为w1.grad或者w2.grad等。
[注]可以对tensor类型的数据进行.norm查看tensor的norm,也可以对梯度信息进行.norm。