zoukankan      html  css  js  c++  java
  • PyTorch中一些损失函数的使用

    1. 均方误差(MSE)
      使用方法如下:
    import torch
    import torch.nn as nn
    a = torch.tensor([[1,2,3],[4,5,6]],dtype=torch.float)
    b = torch.tensor([[1,3,5],[6,5,4]],dtype=torch.float)
    print(a)
    print(b)
    loss = nn.MSELoss()
    l = loss(a,b)
    print(l)
    

    其中MSELoss是定义在torch.nn下的一个类,使用的时候如上述方法一样调用即可,实例化该类的时候一般只用到一个参数reduction,改参数的取值有none,mean,sum三种,默认情况下是mean,即求取平均值。

    1. 二元交叉熵损失函数(BCELoss)
      使用方法如下:
    a = torch.tensor([[0.5],[0.2],[0.4]],dtype=torch.float)
    target = torch.tensor([[1],[0],[1]],dtype=torch.float)
    loss = nn.BCELoss()
    print(loss(a,target))
    

    同样BCELoss是定义在torch.nn下的一个类,下面是官网对于该函数的一些解释:

    计算时输入的input和target是相同的形状,并且input的数值应该在0-1之间,也就是sigmoid后的数据。

    1. BCEWithLogitsLoss
      该损失函数相当于是将sigmoid函数和BCELoss组合在一起了,上面要求input的数值应该在0-1之间,也就是需要实现调用sigmoid函数,该损失函数内部会自动进行sigmoid计算,实现传入的数据不需要再进行此操作。

    2. CrossEntropyLoss
      该损失函数是将sofmax,log,NLLLoss三个函数组合在一起了,将输入的数据依次进行这三个操作。
      传入相应参数的形状如下所示:

      其中N是batch_size,C是种类数,需要注意的是target是一维的,最后计算的时候会将对应位置的数据取出来进行损失函数的计算。

  • 相关阅读:
    洛谷P3003 [USACO10DEC]苹果交货Apple Delivery
    洛谷P1576 最小花费
    洛谷P1821 [USACO07FEB]银牛派对Silver Cow Party
    洛谷P1948 [USACO08JAN]电话线Telephone Lines
    洛谷P3371【模板】单源最短路径
    洛谷P2384最短路
    FirstOfAll
    Proxy模式:管理第三方API
    Abstract Server模式,Adapter模式和Bridge模式
    Observer模式
  • 原文地址:https://www.cnblogs.com/noob-l/p/15122229.html
Copyright © 2011-2022 走看看