zoukankan      html  css  js  c++  java
  • torch 中的损失函数

    1. NLLLoss 和 CrossEntropyLoss
      在图片单标签分类时,输入m张图片,输出一个m*N的Tensor,其中N是分类个数。比如输入3张图片,分3类,最后的输出是一个3*3的Tensor
    input = torch.tensor([[-0.1123, -0.6028, -0.0450],
                  [ 0.1596,  0.2215, -1.0176],
                  [-0.2359, -0.7898,  0.7097]])
    

    第123行分别是第123张图片的结果,假设第123列分别是猫、狗和猪的分类得分。
    first step: 对每一行使用Softmax,这样可以得到每张图片的概率分布。概率最大的为:1:猪;2:狗;3:猪。

    sm = torch.nn.Softmax(dim=1)
    sm(input)
    tensor([[0.3729, 0.2283, 0.3988],
            [0.4216, 0.4485, 0.1299],
            [0.2410, 0.1385, 0.6205]])
    

    second step: 对softmax结果取对数

    torch.log(sm(input))
    tensor([[-0.9865, -1.4770, -0.9192],
            [-0.8637, -0.8019, -2.0409],
            [-1.4229, -1.9767, -0.4773]])
    

    Softmax后的数值都在0~1之间,所以log之后值域是负无穷到0。
    NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。
    假设我们现在Target是[0,2,1](第一张图片是猫,第二张是猪,第三张是狗)。第一行取第0个元素,第二行取第2个,第三行取第1个,去掉负号,结果是:[0.9865,2.0409,1.9767]。再求个均值,结果是:1.66
    对比NLLLoss的结果

    loss = torch.nn.NLLLoss()
    loss(torch.log(sm(input)),target)
    # 1.6681
    

    CrossEntropyLoss 相当于上述步骤的组合,Softmax–Log–NLLLoss合并成一步

    loss2 = torch.nn.CrossEntropyLoss()
    loss2(input,target)
    # 1.6681
    
  • 相关阅读:
    sqlsever中生成GUID的方法
    部署项目到服务器
    读后感
    第二次作业
    课堂作业
    第一次作业 开发环境配置介绍
    第二次结对作业
    代码审查
    最大连续子数组和
    单元测试
  • 原文地址:https://www.cnblogs.com/leimu/p/13346372.html
Copyright © 2011-2022 走看看