  • PyTorch loss functions: NLLLoss

    https://blog.csdn.net/weixin_40476348/article/details/94562240

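    Commonly used for multi-class classification. Before passing input to NLLLoss, apply log_softmax to it: this converts input into a probability distribution and then takes the natural logarithm (base e).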
    class torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=-100, 
    					   reduce=None, reduction='mean')
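
The log_softmax requirement described above can be sketched as follows. The tensor names and values here are illustrative assumptions, not from the original post; the check relies on the documented fact that F.cross_entropy combines log_softmax and nll_loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical raw network outputs (logits) for 4 samples and 3 classes.
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])  # class indices, dtype torch.long

# Step 1: turn logits into log-probabilities along the class dimension.
log_probs = F.log_softmax(logits, dim=1)

# Step 2: NLLLoss picks out -log_probs[i, target[i]] and averages over samples.
loss_nll = F.nll_loss(log_probs, target)

# cross_entropy performs both steps in one call.
loss_ce = F.cross_entropy(logits, target)

print(torch.allclose(loss_nll, loss_ce))

# Exponentiating the log-probabilities gives rows that sum to 1.
row_sums = log_probs.exp().sum(dim=1)
print(torch.allclose(row_sums, torch.ones(4)))
```

Because log-probabilities are never positive, the loss computed this way is never negative.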
    
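    Formula: loss(input, class) = -input[class]
    Worked example: with input = [-0.1187, 0.2110, 0.7463] and target = [1], loss = -0.2110 (raw values are used here only to show the indexing; real log_softmax outputs are all <= 0, so the loss is non-negative)
    Intuition: it is as if target were converted to a one-hot encoding, dotted with input, and the result negated
    In code: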
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    

    torch.manual_seed(2019)

    output = torch.randn(1, 3) # network output (raw values, no log_softmax, to show the indexing)
    target = torch.ones(1, dtype=torch.long).random_(3) # ground-truth label
    print(output)
    print(target)

    # Call the functional form directly
    loss = F.nll_loss(output, target)
    print(loss)

    # Instantiate the module class
    criterion = nn.NLLLoss()
    loss = criterion(output, target)
    print(loss)

    """
    tensor([[-0.1187, 0.2110, 0.7463]])
    tensor([1])
    tensor(-0.2110)
    tensor(-0.2110)
    """
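
The one-hot intuition can be checked with a minimal sketch that reuses the numbers printed above. F.one_hot is a standard PyTorch function; the variable names are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

# Values copied from the printed output above.
inp = torch.tensor([[-0.1187, 0.2110, 0.7463]])
target = torch.tensor([1])

# Encode the target as a one-hot row: [[0., 1., 0.]]
one_hot = F.one_hot(target, num_classes=3).to(inp.dtype)

# Negated dot product per row, then the mean over rows (the default reduction).
manual = -(one_hot * inp).sum(dim=1).mean()

builtin = F.nll_loss(inp, target)

# Both evaluate to -0.2110 for these values.
print(manual, builtin)
```

The element-wise product zeroes out every class except the target one, so the sum simply selects input[class].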

    If input has shape M x N, the loss defaults to the mean of the M per-sample losses; reduction='none' returns all M losses individually.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    

    torch.manual_seed(2019)

    output = torch.randn(2, 3) # network output (raw values, no log_softmax, to show the indexing)
    target = torch.ones(2, dtype=torch.long).random_(3) # ground-truth labels
    print(output)
    print(target)

    # Call the functional form directly (default reduction='mean')
    loss = F.nll_loss(output, target)
    print(loss)

    # Instantiate the module class, keeping per-sample losses
    criterion = nn.NLLLoss(reduction='none')
    loss = criterion(output, target)
    print(loss)

    """
    tensor([[-0.1187, 0.2110, 0.7463],
            [-0.6136, -0.1186, 1.5565]])
    tensor([2, 0])
    tensor(-0.0664)
    tensor([-0.7463, 0.6136])
    """
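
As a rough sketch of how the reduction modes relate, the snippet below reuses the numbers printed above and compares 'none', 'mean', and 'sum'. The variable names are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

# Values copied from the printed output above.
inp = torch.tensor([[-0.1187, 0.2110, 0.7463],
                    [-0.6136, -0.1186, 1.5565]])
target = torch.tensor([2, 0])

# One loss per sample: -inp[0, 2] and -inp[1, 0]
per_sample = F.nll_loss(inp, target, reduction='none')

# Default behaviour: average of the per-sample losses.
mean_loss = F.nll_loss(inp, target, reduction='mean')

# Sum of the per-sample losses.
sum_loss = F.nll_loss(inp, target, reduction='sum')

print(per_sample)
print(torch.allclose(mean_loss, per_sample.mean()))
print(torch.allclose(sum_loss, per_sample.sum()))
```

With no class weights and no ignored targets, 'mean' is exactly the arithmetic mean of the 'none' result, which is where the -0.0664 above comes from.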

  • Original article: https://www.cnblogs.com/leebxo/p/11913939.html