zoukankan      html  css  js  c++  java
  • tf.nn.sigmoid_cross_entropy_with_logits 分类

    tf.nn.sigmoid_cross_entropy_with_logits(_sentinel=None,,labels=None,logits=None,name=None)
    
    logits和labels必须有相同的类型和大小
    
    参数:
    _sentinel:内部的并不使用
    labels:和logits的shape和type一样
    logits:类型为float32或者float64
    name:操作的名称,可省
    
    返回的是:一个张量,和logits的大小一致。是逻辑损失
    
    

    sample

    import numpy as np
    import tensorflow as tf
     
    labels=np.array([[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]])
    logits=np.array([[11.,8.,7.],[10.,14.,3.],[1.,2.,4.]])
    
    y_pred=tf.math.sigmoid(logits)
    prob_error1=-labels*tf.math.log(y_pred)-(1-labels)*tf.math.log(1-y_pred)
     
    labels1=np.array([[0.,1.,0.],[1.,1.,0.],[0.,0.,1.]])#不一定只属于一个类别
    logits1=np.array([[1.,8.,7.],[10.,14.,3.],[1.,2.,4.]])
    y_pred1=tf.math.sigmoid(logits1)
    prob_error11=-labels1*tf.math.log(y_pred1)-(1-labels1)*tf.math.log(1-y_pred1)
     
    with tf.compat.v1.Session() as sess:
        print("1:")
        print(sess.run(prob_error1))
        print("2:")
        print(sess.run(prob_error11))
        print("3:")
        print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,logits=logits)))
        print("4:")
        print(sess.run(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels1,logits=logits1)))
    

    output

    1和3,2和4结果一样
    1:
    [[1.67015613e-05 8.00033541e+00 7.00091147e+00]
     [1.00000454e+01 8.31528373e-07 3.04858735e+00]
     [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
    2:
    [[1.31326169e+00 3.35406373e-04 7.00091147e+00]
     [4.53988992e-05 8.31528373e-07 3.04858735e+00]
     [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
    3:
    [[1.67015613e-05 8.00033541e+00 7.00091147e+00]
     [1.00000454e+01 8.31528373e-07 3.04858735e+00]
     [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
    4:
    [[1.31326169e+00 3.35406373e-04 7.00091147e+00]
     [4.53988992e-05 8.31528373e-07 3.04858735e+00]
     [1.31326169e+00 2.12692801e+00 1.81499279e-02]]
    
    
  • 相关阅读:
    java JDBC DAO ORM Domain
    《硅谷钢铁侠-- 埃隆·马斯克的冒险人生》
    在启动MYSQL时出现问题:“ERROR 2003 (HY000): Can't connect to MySQL server on 'localhost' (10061)”
    使用IntelliJ IDEA开发java web
    [django]用日期来查询datetime类型字段
    2020/5/31
    图解排序算法(三)之堆排序
    图解排序算法(二)之希尔排序
    图解排序算法(一)之3种简单排序(选择,冒泡,直接插入)
    Oracle约束(Constraint)详解
  • 原文地址:https://www.cnblogs.com/smallredness/p/11199541.html
Copyright © 2011-2022 走看看