zoukankan      html  css  js  c++  java
  • torch.nn.BCELoss用法

    1. 定义
      数学公式为 Loss = -w * [p * log(q) + (1-p) * log(1-q)] ,其中p、q分别为理论标签、实际预测值,w为权重。这里的log对应数学上的ln。

      PyTorch对应函数为:
        torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction=‘mean’)
      计算目标值和预测值之间的二进制交叉熵损失函数。

      有四个可选参数:weight、size_average、reduce、reduction

    1. weight必须和target的shape一致,默认为none。定义BCELoss的时候指定即可。
    2. 默认情况下 nn.BCELoss(),reduce = True,size_average = True。
    3. 如果reduce为False,size_average不起作用,返回向量形式的loss。
    4. 如果reduce为True,size_average为True,返回loss的均值,即loss.mean()。
    5. 如果reduce为True,size_average为False,返回loss的和,即loss.sum()。
    6. 如果reduction = ‘none’,直接返回向量形式的 loss。
    7. 如果reduction = ‘sum’,返回loss之和。
    8. 如果reduction = ''elementwise_mean,返回loss的平均值。
    9. 如果reduction = ''mean,返回loss的平均值

    2. 验证代码

    1>

    import torch
    import torch.nn as nn
    
    m = nn.Sigmoid()
    
    loss = nn.BCELoss(size_average=False, reduce=False)
    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)
    lossinput = m(input)
    output = loss(lossinput, target)
    
    print("输入值:")
    print(lossinput)
    print("输出的目标值:")
    print(target)
    print("计算loss的结果:")
    print(output)

     2>

    import torch
    import torch.nn as nn
    
    m = nn.Sigmoid()
    
    loss = nn.BCELoss(size_average=True, reduce=False)
    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)
    lossinput = m(input)
    output = loss(lossinput, target)
    
    print("输入值:")
    print(lossinput)
    print("输出的目标值:")
    print(target)
    print("计算loss的结果:")
    print(output)

     3>

    import torch
    import torch.nn as nn
    
    m = nn.Sigmoid()
    
    loss = nn.BCELoss(size_average=True, reduce=True)
    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)
    lossinput = m(input)
    output = loss(lossinput, target)
    
    print("输入值:")
    print(lossinput)
    print("输出的目标值:")
    print(target)
    print("计算loss的结果:")
    print(output)

     4>

    import torch
    import torch.nn as nn
    
    m = nn.Sigmoid()
    
    loss = nn.BCELoss(size_average=False, reduce=True)
    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)
    lossinput = m(input)
    output = loss(lossinput, target)
    
    print("输入值:")
    print(lossinput)
    print("输出的目标值:")
    print(target)
    print("计算loss的结果:")
    print(output)

     5>

    import torch
    import torch.nn as nn
    
    m = nn.Sigmoid()
    
    loss = nn.BCELoss(reduction = 'none')
    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)
    lossinput = m(input)
    output = loss(lossinput, target)
    
    print("输入值:")
    print(lossinput)
    print("输出的目标值:")
    print(target)
    print("计算loss的结果:")
    print(output)

     6>

    import torch
    import torch.nn as nn
    
    m = nn.Sigmoid()
    weights=torch.randn(3)
    
    loss = nn.BCELoss(weight=weights,size_average=False, reduce=False)
    input = torch.randn(3, requires_grad=True)
    target = torch.empty(3).random_(2)
    lossinput = m(input)
    output = loss(lossinput, target)
    
    print("输入值:")
    print(lossinput)
    print("输出的目标值:")
    print(target)
    print("权重值")
    print(weights)
    print("计算loss的结果:")
    print(output)

     

    2. 验证代码
    1>
    import torchimport torch.nn as nn
    m = nn.Sigmoid()
    loss = nn.BCELoss(size_average=False, reduce=False)input = torch.randn(3, requires_grad=True)target = torch.empty(3).random_(2)lossinput = m(input)output = loss(lossinput, target)
    print("输入值:")print(lossinput)print("输出的目标值:")print(target)print("计算loss的结果:")print(output)1234567891011121314151617
    2>
    import torchimport torch.nn as nn
    m = nn.Sigmoid()
    loss = nn.BCELoss(size_average=True, reduce=False)input = torch.randn(3, requires_grad=True)target = torch.empty(3).random_(2)lossinput = m(input)output = loss(lossinput, target)
    print("输入值:")print(lossinput)print("输出的目标值:")print(target)print("计算loss的结果:")print(output)1234567891011121314151617
    3>
    import torchimport torch.nn as nn
    m = nn.Sigmoid()
    loss = nn.BCELoss(size_average=True, reduce=True)input = torch.randn(3, requires_grad=True)target = torch.empty(3).random_(2)lossinput = m(input)output = loss(lossinput, target)
    print("输入值:")print(lossinput)print("输出的目标值:")print(target)print("计算loss的结果:")print(output)1234567891011121314151617
    4>
    import torchimport torch.nn as nn
    m = nn.Sigmoid()
    loss = nn.BCELoss(size_average=False, reduce=True)input = torch.randn(3, requires_grad=True)target = torch.empty(3).random_(2)lossinput = m(input)output = loss(lossinput, target)
    print("输入值:")print(lossinput)print("输出的目标值:")print(target)print("计算loss的结果:")print(output)1234567891011121314151617
    5>
    import torchimport torch.nn as nn
    m = nn.Sigmoid()
    loss = nn.BCELoss(reduction = 'none')input = torch.randn(3, requires_grad=True)target = torch.empty(3).random_(2)lossinput = m(input)output = loss(lossinput, target)
    print("输入值:")print(lossinput)print("输出的目标值:")print(target)print("计算loss的结果:")print(output)1234567891011121314151617
    6>
    import torchimport torch.nn as nn
    m = nn.Sigmoid()weights=torch.randn(3)
    loss = nn.BCELoss(weight=weights,size_average=False, reduce=False)input = torch.randn(3, requires_grad=True)target = torch.empty(3).random_(2)lossinput = m(input)output = loss(lossinput, target)
    print("输入值:")print(lossinput)print("输出的目标值:")print(target)print("权重值")print(weights)print("计算loss的结果:")print(output)1234567891011121314151617181920
    ————————————————版权声明:本文为CSDN博主「qq_29631521」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。原文链接:https://blog.csdn.net/qq_29631521/article/details/104907401

    因上求缘,果上努力~~~~ 作者:每天卷学习,转载请注明原文链接:https://www.cnblogs.com/BlairGrowing/p/15510527.html

  • 相关阅读:
    vue生命周期详细解析
    Chrome浏览器中onunload有时候没反应怎么办
    JavaScript中<button>与<input type="button"..的区别
    java泛型
    hashCode与equals
    HttpClient HttpServlet HttpUrlConnection
    think in java 笔记
    红黑树
    AC自动机
    并查集
  • 原文地址:https://www.cnblogs.com/BlairGrowing/p/15510527.html
Copyright © 2011-2022 走看看