zoukankan      html  css  js  c++  java
  • Focal Loss 的Pytorch 实现以及实验

    Focal Loss 的Pytorch 实现以及实验

    Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行decay的一种损失函数。是对标准的Cross Entropy Loss 的一种改进。 F L对于简单样本(p比较大)回应较小的loss。

    如论文中的图1, 在p=0.6时, 标准的CE然后又较大的loss, 但是对于FL就有相对较小的loss回应。这样就是对简单样本的一种decay。其中alpha 是对每个类别在训练数据中的频率有关, 但是下面的实现我们是基于alpha=1进行实验的。

    标准的Cross Entropy 为:

    [公式]

    Focal Loss 为:

    [公式]

    [公式]

    其中 [公式]

    以上公式为下面实现代码的基础。

     

    采用基于pytorch 的yolo2 在VOC的上的实验结果如下:

     

    在单纯的替换了CrossEntropyLoss之后就有1个点左右的提升。效果还是比较显著的。本实验中采用的是darknet19, 要是采用更大的网络就可能会有更好的性能提升。这个实验结果已经能很好的说明的Focal Loss 的对于检测的价值了。

     

    一点没做的但是可能会提升性能:

    1. 采用soft - gamma: 在训练的过程中阶段性的增大gamma 可能会有更好的性能提升

     

     

    本文实验中采用的Focal Loss 代码如下。

    关于Focal Loss 的数学推倒在文章:Focal Loss 的前向与后向公式推导

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.autograd import Variable
    
    class FocalLoss(nn.Module):
        r"""
            This criterion is a implemenation of Focal Loss, which is proposed in 
            Focal Loss for Dense Object Detection.
    
                Loss(x, class) = - alpha (1-softmax(x)[class])^gamma log(softmax(x)[class])
    
            The losses are averaged across observations for each minibatch.
    
            Args:
                alpha(1D Tensor, Variable) : the scalar factor for this criterion
                gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                       putting more focus on hard, misclassified examples
                size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                    However, if the field size_average is set to False, the losses are
                                    instead summed for each minibatch.
    
    
        """
        def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
            super(FocalLoss, self).__init__()
            if alpha is None:
                self.alpha = Variable(torch.ones(class_num, 1))
            else:
                if isinstance(alpha, Variable):
                    self.alpha = alpha
                else:
                    self.alpha = Variable(alpha)
            self.gamma = gamma
            self.class_num = class_num
            self.size_average = size_average
    
        def forward(self, inputs, targets):
            N = inputs.size(0)
            C = inputs.size(1)
            P = F.softmax(inputs)
    
            class_mask = inputs.data.new(N, C).fill_(0)
            class_mask = Variable(class_mask)
            ids = targets.view(-1, 1)
            class_mask.scatter_(1, ids.data, 1.)
            #print(class_mask)
    
    
            if inputs.is_cuda and not self.alpha.is_cuda:
                self.alpha = self.alpha.cuda()
            alpha = self.alpha[ids.data.view(-1)]
    
            probs = (P*class_mask).sum(1).view(-1,1)
    
            log_p = probs.log()
            #print('probs size= {}'.format(probs.size()))
            #print(probs)
    
            batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
            #print('-----bacth_loss------')
            #print(batch_loss)
    
    
            if self.size_average:
                loss = batch_loss.mean()
            else:
                loss = batch_loss.sum()
            return loss
  • 相关阅读:
    HTML 基础 元素 标签
    HTML5 元素介绍
    网站程序 模板下载 下载 ftp
    域名解析和空间绑定
    如何选择云虚拟主机操作系统?
    网站备案查询
    响应式网站01
    项目中使用百度统计和友盟统计
    项目中使用http referer,为了盗取图片资源
    vue-awesome-swiper中的数据异步加载
  • 原文地址:https://www.cnblogs.com/yumoye/p/11253049.html
Copyright © 2011-2022 走看看