zoukankan      html  css  js  c++  java
  • 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类、多分类等)任务目标,可以参考文章keras中两种交叉熵损失函数的探讨,其结合keras的API讨论了两者的计算原理和应用原理。

    本文主要是介绍TF中的接口调用方式。

    一、二分类交叉熵

    对应的是网络输出单个节点,这个节点将被sigmoid处理,使用阈值分类为0或者1的问题。此类问题logits和labels必须具有相同的type和shape

    原理介绍

    x = logits, z = labels.
    logistic loss 计算式为: 其中交叉熵(cross entripy)基本函数式

           z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
        = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
        = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
        = (1 - z) * x + log(1 + exp(-x))
        = x - x * z + log(1 + exp(-x))

    对于x<0时,为了避免计算exp(-x)时溢出,我们使用以下这种形式表示

           x - x * z + log(1 + exp(-x))
        = log(exp(x)) - x * z + log(1 + exp(-x))
        = - x * z + log(1 + exp(x))

    综合x>0和x<0的情况,并防止溢出我们使用如下公式,

    max(x, 0) - x *z + log(1 + exp(-abs(x)))

    接口介绍

    import numpy as np
    import tensorflow as tf
    
    input_data = tf.Variable(np.random.rand(1, 3), dtype=tf.float32)
    # np.random.rand()传入一个shape,返回一个在[0,1)区间符合均匀分布的array
    
    output = tf.nn.sigmoid_cross_entropy_with_logits(logits=input_data, labels=[[1.0, 0.0, 0.0]])
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        print(sess.run(output))
        # [[ 0.5583781   1.06925142  1.08170223]]
    

    二、多分类交叉熵

    对应的是网络输出多个节点,每个节点表示1个class的得分,使用Softmax最终处理的分类问题。

    原理介绍

    cross_entropy = -tf.reduce_mean(y * tf.log(tf.clip_by_value(y_pre, 1e-10, 1.0))
    

    调用一下:

    import tensorflow as tf
     
    input_data = tf.Variable([[0.2, 0.1, 0.9], [0.3, 0.4, 0.6]], dtype=tf.float32)
    labels=tf.constant([[1,0,0], [0,1,0]], dtype=tf.float32)
    
    cross_entropy = -tf.reduce_mean(labels * tf.log(tf.clip_by_value(input_data, 1e-10, 1.0)))
                                    
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        print(sess.run(output))
    

    接口介绍

    softmax之后,计算输出层全部节点各自的交叉熵(输出向量而非标量)

    cross_entropy_mean = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.argmax(labels,1), logits=logits), name='cross_entropy')
     
    cross_entropy_mean = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=labels), name='cross_entropy')
    

    tf.nn.softmax_cross_entropy_with_logits()

    函数的参数label是稀疏表示的,比如表示一个3分类的一个样本的标签,稀疏表示的形式为[0,0,1]这个表示这个样本为第3个分类,而非稀疏表示就表示为2,同理[0,1,0]就表示样本属于第2个分类,而其非稀疏表示为1。

    import tensorflow as tf
     
    input_data = tf.Variable([[0.2, 0.1, 0.9], [0.3, 0.4, 0.6]], dtype=tf.float32)
    output = tf.nn.softmax_cross_entropy_with_logits(logits=input_data, labels=[[1,0,0],
                                                                                [0,1,0]])
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        print(sess.run(output))
    

    tf.nn.sparse_softmax_cross_entropy_with_logits()

    此函数大致与tf.nn.softmax_cross_entropy_with_logits的计算方式相同,
    适用于每个类别相互独立且排斥的情况,一幅图只能属于一类,而不能同时包含一条狗和一只大象

    但是在对于labels的处理上有不同之处,labels从shape来说此函数要求shape为[batch_size],
    labels[i]是[0,num_classes)的一个索引, type为int32或int64,即labels限定了是一个一阶tensor,
    并且取值范围只能在分类数之内,表示一个对象只能属于一个类别

    import tensorflow as tf
    
    input_data = tf.Variable([[0.2, 0.1, 0.9], [0.3, 0.4, 0.6]], dtype=tf.float32)
    output = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=input_data, labels=[0, 2])
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        print(sess.run(output))
    # [ 1.36573195  0.93983102]
    

    比tf.nn.softmax_cross_entropy_with_logits多了一步将labels稀疏化的操作。因为深度学习中,图片一般是用非稀疏的标签的,所以tf.nn.sparse_softmax_cross_entropy_with_logits()的频率比tf.nn.softmax_cross_entropy_with_logits高。

    不过两者输出尺寸等于输入shape去掉最后一维(上面输入[2*3],输出[2]),所以均常和tf.reduce_mean()连用。

  • 相关阅读:
    局域网无法访问vmware虚拟机WEB服务器解决办法
    zipimport.ZipImportError: can't decompress data; zlib not available 解决办法
    如何在win下使用linux命令
    《redisphp中文参考手册》php版
    Python关键字参数与非关键字参数(可变参数)详解
    Python与 PHP使用递归建立多层目录函数
    第一场个人图论专题
    poj_2762,弱连通
    word宏的问题
    fatal error LNK1123: 转换到 COFF 期间失败: 文件无效或损坏
  • 原文地址:https://www.cnblogs.com/hellcat/p/8568005.html
Copyright © 2011-2022 走看看