zoukankan      html  css  js  c++  java
  • 损失函数

    1、torch.nn.CrossEntropyLoss()

    用于多分类问题

    loss_func=torch.nn.CrossEntropyLoss()

    loss=loss_func(input_data,input_target)

    其中input_data的shape一般是(batch_size,output_features),而input_target的shape是(batch_size)

    返回的loss是一个张量,但是只有一个数,代表的是计算结果的交叉商损失值

    交叉商的计算方法是:

    将输入的数据在最后一个维度上做softmax运算

    对softmax后的数据取log,注意softmax后所有的数值介于0和1之间,所以log后所有的数值全都是负数

    softmax_loged_data=torch.log(torch.nn.Softmax(dim=-1)(input_data))

    根据标签对应的数值去softmax_loged_data中索引出相应的数值并且去掉符号,

    将这batch_size个数值相加取平均后就是input_data与input_target的交叉商损失值

    2、torch.nn.MSELoss()

    用于回归问题

  • 相关阅读:
    P3275 [SCOI2011]糖果 题解
    hdu 2962 题解
    hdu 2167 题解
    hdu 2476 题解
    hdu 5418 题解
    2019.10.16&17小结
    poj 3061 题解(尺取法|二分
    poj 1852&3684 题解
    NOIP2017[提高组] 宝藏 题解
    一类经典问题的解法
  • 原文地址:https://www.cnblogs.com/liujianing/p/12357425.html
Copyright © 2011-2022 走看看