  • Focal Loss implementation in TensorFlow

        import tensorflow as tf

        def focal_loss(pred, y, alpha=0.25, gamma=2):
            r"""Compute focal loss for predictions.
                Multi-label focal loss formula:
                    FL = -alpha * (z-p)^gamma * log(p) - (1-alpha) * p^gamma * log(1-p),
                    where alpha = 0.25, gamma = 2, p = sigmoid(x), z = target_tensor.
            Args:
             pred: A float tensor of shape [batch_size, num_anchors,
                num_classes] representing the predicted logits for each class
             y: A float tensor of shape [batch_size, num_anchors,
                num_classes] representing one-hot encoded classification targets
             alpha: A scalar tensor for focal loss alpha hyper-parameter
             gamma: A scalar tensor for focal loss gamma hyper-parameter
            Returns:
                loss: A (scalar) tensor representing the value of the loss function
            """
            # pred holds logits, so convert them to probabilities p = sigmoid(x) first.
            p = tf.sigmoid(pred)
            zeros = tf.zeros_like(p, dtype=p.dtype)

            # For a positive entry only the first term contributes; the second term is 0.
            # y > zeros <=> z = 1, so the positive coefficient is z - p.
            # Select the positive entries and fill the rest with 0.
            pos_p_sub = tf.where(y > zeros, y - p, zeros)

            # For a negative entry only the second term contributes; the first term is 0.
            # y > zeros <=> z = 1, so the negative coefficient there is 0, elsewhere p.
            # Select the negative entries and fill the rest with 0.
            neg_p_sub = tf.where(y > zeros, zeros, p)

            # Parentheses keep the two-line expression a single statement;
            # clipping avoids log(0).
            per_entry_cross_ent = (
                - alpha * (pos_p_sub ** gamma) * tf.math.log(tf.clip_by_value(p, 1e-8, 1.0))
                - (1 - alpha) * (neg_p_sub ** gamma) * tf.math.log(tf.clip_by_value(1.0 - p, 1e-8, 1.0))
            )

            return tf.reduce_sum(per_entry_cross_ent)
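To sanity-check the formula without TensorFlow, here is a minimal pure-Python sketch of the same per-entry sigmoid focal loss. It assumes flat lists of logits and targets instead of tensors, and the helper names `sigmoid`, `focal_loss_entry` and `focal_loss` are hypothetical, not part of any library. It shows how the modulating factor shrinks the loss of well-classified entries compared with plain binary cross-entropy.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, p = sigmoid(x)."""
    return 1.0 / (1.0 + math.exp(-x))


def focal_loss_entry(logit: float, z: float, alpha: float = 0.25,
                     gamma: float = 2.0, eps: float = 1e-8) -> float:
    """Per-entry sigmoid focal loss (hypothetical helper).

    Mirrors the TF code above: an entry with z > 0 is treated as positive.
    """
    p = sigmoid(logit)
    if z > 0:
        # Positive entry: only -alpha * (z - p)^gamma * log(p) survives.
        return -alpha * (z - p) ** gamma * math.log(min(max(p, eps), 1.0))
    # Negative entry: only -(1 - alpha) * p^gamma * log(1 - p) survives.
    return -(1 - alpha) * p ** gamma * math.log(min(max(1.0 - p, eps), 1.0))


def focal_loss(logits, targets, alpha: float = 0.25, gamma: float = 2.0) -> float:
    """Sum over all entries, like tf.reduce_sum in the TF version (hypothetical helper)."""
    return sum(focal_loss_entry(x, z, alpha, gamma) for x, z in zip(logits, targets))


if __name__ == "__main__":
    # An easy positive (logit 4.0) versus a hard positive (logit -4.0).
    print(focal_loss_entry(4.0, 1.0), focal_loss_entry(-4.0, 1.0))
```

With `gamma=0` and `alpha=0.5` every entry reduces to half the ordinary binary cross-entropy, which is a quick way to check an implementation. Raising `gamma` suppresses easy entries, such as a positive with logit 4.0, far more than hard ones.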
  • Original post: https://www.cnblogs.com/callyblog/p/11168555.html