zoukankan      html  css  js  c++  java
  • Pytorch-区分nn.BCELoss()、nn.BCEWithLogitsLoss()和nn.CrossEntropyLoss() 的用法

    详细理论部分可参考https://www.cnblogs.com/wanghui-garcia/p/10862733.html

    BCELoss()和BCEWithLogitsLoss()的输出logits和目标labels的形状相同。

     1 import torch 
     2 import torch.nn as nn
     3 import torch.nn.functional as F
     4 
     5 m = nn.Sigmoid()
     6 
     7 loss_f1 = nn.BCELoss()  
     8 loss_f2 = nn.BCEWithLogitsLoss() 
     9 loss_f3 = nn.CrossEntropyLoss() 
    10 
    11 logits = torch.randn(3, 2)
    12 labels = torch.FloatTensor([[0, 1], [1, 0], [1, 0]])
    13 
    14 print(loss_f1(m(logits), labels))      #tensor(0.9314),注意logits先被激活函数作用 
    15 print(loss_f2(logits, labels))         #tensor(0.9314)
    16 
    17 label2 = torch.LongTensor([1, 0, 0])
    18 print(loss_f3(logits, label2))         #tensor(1.2842)

     如果label2也想变成labels,然后通过BCELoss进行计算的话,可以先转变成独热编码的形式:

    1 encode = F.one_hot(label2, num_classes = 2)               #encode的值和labels一样,但是类型是LongTensor
    2 print(loss_f1(m(logits), encode.type(torch.float32)))     #tensor(0.9314)
  • 相关阅读:
    抚琴弹唱东流水
    借点阳光给你
    日月成双行影单
    一夜飘雪入冬来
    悼念钱学森
    我的青春谁作主
    重游望江楼有感
    雪后暖阳
    满城尽添黄金装
    敢叫岁月不冬天
  • 原文地址:https://www.cnblogs.com/cxq1126/p/13794877.html
Copyright © 2011-2022 走看看