- NLLLoss 和 CrossEntropyLoss
在图片单标签分类时,输入m张图片,输出一个m*N的Tensor,其中N是分类个数。比如输入3张图片,分3类,最后的输出是一个3*3的Tensor
input = torch.tensor([[-0.1123, -0.6028, -0.0450],
[ 0.1596, 0.2215, -1.0176],
[-0.2359, -0.7898, 0.7097]])
第123行分别是第123张图片的结果,假设第123列分别是猫、狗和猪的分类得分。
first step: 对每一行使用Softmax,这样可以得到每张图片的概率分布。概率最大的为:1:猪;2:狗;3:猪。
sm = torch.nn.Softmax(dim=1)
sm(input)
tensor([[0.3729, 0.2283, 0.3988],
[0.4216, 0.4485, 0.1299],
[0.2410, 0.1385, 0.6205]])
second step: 对softmax结果取对数
torch.log(sm(input))
tensor([[-0.9865, -1.4770, -0.9192],
[-0.8637, -0.8019, -2.0409],
[-1.4229, -1.9767, -0.4773]])
Softmax后的数值都在0~1之间,所以log之后值域是负无穷到0。
NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。
假设我们现在Target是[0,2,1](第一张图片是猫,第二张是猪,第三张是狗)。第一行取第0个元素,第二行取第2个,第三行取第1个,去掉负号,结果是:[0.9865,2.0409,1.9767]。再求个均值,结果是:1.66
对比NLLLoss的结果
loss = torch.nn.NLLLoss()
loss(torch.log(sm(input)),target)
# 1.6681
CrossEntropyLoss 相当于上述步骤的组合,Softmax–Log–NLLLoss合并成一步
loss2 = torch.nn.CrossEntropyLoss()
loss2(input,target)
# 1.6681