zoukankan      html  css  js  c++  java
  • PyTorch中一些损失函数的使用

    1. 均方误差(MSE)
      使用方法如下:
    import torch
    import torch.nn as nn
    a = torch.tensor([[1,2,3],[4,5,6]],dtype=torch.float)
    b = torch.tensor([[1,3,5],[6,5,4]],dtype=torch.float)
    print(a)
    print(b)
    loss = nn.MSELoss()
    l = loss(a,b)
    print(l)
    

    其中MSELoss是定义在torch.nn下的一个类,使用的时候如上述方法一样调用即可,实例化该类的时候一般只用到一个参数reduction,改参数的取值有none,mean,sum三种,默认情况下是mean,即求取平均值。

    1. 二元交叉熵损失函数(BCELoss)
      使用方法如下:
    a = torch.tensor([[0.5],[0.2],[0.4]],dtype=torch.float)
    target = torch.tensor([[1],[0],[1]],dtype=torch.float)
    loss = nn.BCELoss()
    print(loss(a,target))
    

    同样BCELoss是定义在torch.nn下的一个类,下面是官网对于该函数的一些解释:

    计算时输入的input和target是相同的形状,并且input的数值应该在0-1之间,也就是sigmoid后的数据。

    1. BCEWithLogitsLoss
      该损失函数相当于是将sigmoid函数和BCELoss组合在一起了,上面要求input的数值应该在0-1之间,也就是需要实现调用sigmoid函数,该损失函数内部会自动进行sigmoid计算,实现传入的数据不需要再进行此操作。

    2. CrossEntropyLoss
      该损失函数是将sofmax,log,NLLLoss三个函数组合在一起了,将输入的数据依次进行这三个操作。
      传入相应参数的形状如下所示:

      其中N是batch_size,C是种类数,需要注意的是target是一维的,最后计算的时候会将对应位置的数据取出来进行损失函数的计算。

  • 相关阅读:
    拥有最多糖果的孩子
    求1+2+…+n
    网络-中间代理
    Header中的Referer属性表示
    ios13.4post请求出现网错错误 network err
    10.8&10.10
    9.23&9.27
    9.16&9.19
    校内模拟赛划水报告(9.9,9.11)
    男人八题 划水题解
  • 原文地址:https://www.cnblogs.com/noob-l/p/15122229.html
Copyright © 2011-2022 走看看