zoukankan      html  css  js  c++  java
  • sigmoid_cross_entropy_with_logits

    sigmoid_cross_entropy_with_logits

    觉得有用的话,欢迎一起讨论相互学习~


    我的微博我的github我的B站

    函数定义

    def sigmoid_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=invalid-name
                                          labels=None, logits=None,
                                          name=None):
    

    函数意义

    • 这个函数的作用是计算经sigmoid 函数激活之后的交叉熵。
    • 为了描述简洁,我们规定 x = logits,z = targets,那么 Logistic 损失值为:

    [x - x * z + log( 1 + exp(-x) ) ]

    • 对于x<0的情况,为了执行的稳定,使用计算式:

    [-x * z + log(1 + exp(x)) ]

    • 为了确保计算稳定,避免溢出,真实的计算实现如下:

    [max(x, 0) - x * z + log(1 + exp(-abs(x)) ) ]

    • logits 和 targets 必须有相同的数据类型和数据维度。
    • 它适用于每个类别相互独立但互不排斥的情况,在一张图片中,同时包含多个分类目标(大象和狗),那么就可以使用这个函数。

    例子

    import numpy as np
    import tensorflow as tf
    
    input_data = tf.Variable(np.random.rand(1, 3), dtype=tf.float32)
    # np.random.rand()传入一个shape,返回一个在[0,1)区间符合均匀分布的array
    
    output = tf.nn.sigmoid_cross_entropy_with_logits(logits=input_data, labels=[[1.0, 0.0, 0.0]])
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        print(sess.run(output))
        # [[ 0.5583781   1.06925142  1.08170223]]
    

    输入与输出

    输入

    • _sentinel: 一般情况下不怎么使用的参数,可以直接保持默认使其为None
    • logits: 一个Tensor。数据类型是以下之一:float32或者float64。
    • targets: 一个Tensor。数据类型和数据维度都和 logits 相同。
    • name: 为这个操作取个名字。
      输出
    • 一个 Tensor ,数据维度和 logits 相同。

    推导过程

    x = logits, z = labels.

    • logistic loss 计算式为:
    • 其中交叉熵(cross entripy)基本函数式
          z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
          = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
          = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
          = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
          = (1 - z) * x + log(1 + exp(-x))
          = x - x * z + log(1 + exp(-x))
    

    对于x<0时,为了避免计算exp(-x)时溢出,我们使用以下这种形式表示

            x - x * z + log(1 + exp(-x))
          = log(exp(x)) - x * z + log(1 + exp(-x))
          = - x * z + log(1 + exp(x))
    

    综合x>0和x<0的情况,我们使用以下函数式
    $$max(x, 0) - x * z + log(1 + exp(-abs(x)))$$

    注意logits和labels必须具有相同的type和shape

  • 相关阅读:
    STL_算法_查找算法(lower_bound、upper_bound、equal_range)
    Kafka深度解析
    hdoj 1027 Ignatius and the Princess II 【逆康托展开】
    代码二次封装-xUtils(android)
    C++11新特性应用--介绍几个新增的便利算法(用于排序的几个算法)
    谨防串行的状态报告会
    hadoop中NameNode、DataNode和Client三者之间协作关系及通信方式介绍
    Kindeditor JS 取值问题以及上传图片后回调等
    MySQL 8.0.11 报错[ERROR] [MY-011087] Different lower_case_table_names settings for server ('1')
    爱的链条
  • 原文地址:https://www.cnblogs.com/cloud-ken/p/7435421.html
Copyright © 2011-2022 走看看