zoukankan      html  css  js  c++  java
  • tf.nn.softmax & tf.nn.reduce_sum & tf.nn.softmax_cross_entropy_with_logits

    tf.nn.softmax

    softmax是神经网络的最后一层将实数空间映射到概率空间的常用方法,公式如下:

    [softmax(x)_i=frac{exp(x_i)}{sum_jexp(x_j)} ]

    本文意于分析tensorflow中的tf.nn.softmax(),关于softmax的具体推导和相关知识点,参照其它文章

    tensorflow的tf.nn.softmax()函数实现位于这里,可以看到,实现起来相当简明:

    tf.exp(logits)/tf.reduce_sum(tf.exp(logits),axis)
    

    看一个例子:

    x=tf.constant([[[1.0,2.0],[3.0,4.0]],
    [[5.0,6.0],[7.0,9.0]],
    [[9.0,10.0],[11.0,12.0]]])
    
    with tf.Session() as sess:
    	print(sess.run(tf.nn.softmax(x,axis=0)))
    	print(sess.run(tf.nn.softmax(x,axis=1)))
    	print(sess.run(tf.nn.softmax(x,axis=2)))
    

    这里主要关注axis参数,它表示在那个维度上做softmax。从上面可以看到,axis参数传递给了tf.reduce_sum。上述的运行结果类似于:

    axis=0:
    [[[3.2932041e-04 3.2932041e-04]
      [3.2932041e-04 3.2932041e-04]]
     [[1.7980287e-02 1.7980287e-02]
      [1.7980287e-02 1.7980287e-02]]
     [[9.8169035e-01 9.8169035e-01]
      [9.8169035e-01 9.8169035e-01]]]
    
    axis=1:
    [[[0.11920291 0.11920291]
      [0.880797   0.880797  ]]
     [[0.11920291 0.11920291]
      [0.880797   0.880797  ]]
     [[0.11920291 0.11920291]
      [0.880797   0.880797  ]]]
    
    axis=2:
    [[[0.26894143 0.7310586 ]
      [0.26894143 0.7310586 ]]
     [[0.26894143 0.7310586 ]
      [0.26894143 0.7310586 ]]
     [[0.26894143 0.7310586 ]
      [0.26894143 0.7310586 ]]]
    

    这里以axis=0为例,tf.reduce_sum(tf.exp(x),axis=0)的结果为:

    [[  8254.216  22437.285]
     [ 60990.863 165790.34 ]]
    

    tf.exp(x)的结果为:

    [[[2.7182817e+00 7.3890562e+00]
      [2.0085537e+01 5.4598152e+01]]
     [[1.4841316e+02 4.0342880e+02]
      [1.0966332e+03 2.9809580e+03]]
     [[8.1030840e+03 2.2026467e+04]
      [5.9874145e+04 1.6275478e+05]]]
    

    假设最外层axis=0的维度表示样本数,取出第一个样本看其计算过程,可知:

    [[3.2932041e-04 3.2932041e-04]
      [3.2932041e-04 3.2932041e-04]]=
    [[2.7182817e+00 7.3890562e+00]
      [2.0085537e+01 5.4598152e+01]]
     /
     [[  8254.216  22437.285]
     [ 60990.863 165790.34 ]]
    

    也就是样本概率加和为1,也就是对axis=0处做softmax(axis=0维度上,概率加和为1),而其“内部”的值一样。

    tf.reduce_sum

    这里从tf.reduce_sum函数这个角度提一下,tensorflow中的维度这个参数。axis这个参数可以从张量从外向里看,axis=0表示最外一层,举例而言:

    x=tf.constant([[[1.0,2.0],[3.0,4.0]],
    				[[5.0,6.0],[7.0,9.0]],
    				[[9.0,10.0],[11.0,12.0]]])
    
    with tf.Session() as sess:
    	print(sess.run(tf.reduce_sum(x,axis=0)))
    

    上述这个例子中,x的shape为[3,2,2]。最外层的张量有3个元素,现在要对最外层(也就是axis=0)reduce_sum,也就是:

    [[1.0,2.0],[3.0,4.0]]+[[5.0,6.0],[7.0,9.0]]+[[9.0,10.0],[11.0,12.0]]
    =[[15.0,18.0],[21.0,24.0]]
    

    3维张量内部的2维张量,对应位置相加。例如:15.0=1.0+5.0+9.0

    同样的:

    with tf.Session() as sess:
    	print(sess.run(tf.reduce_sum(x,axis=1)))
    	print(sess.run(tf.reduce_sum(x,axis=2)))
    

    axis=1时,是第二层,第二层中每个张量有2个元素,对于第一个第二层(axis=1)张量[[1.0,2.0],[3.0,4.0]],现在要对其reduce_sum,运算过程如下:

    [1.0,2.0]+[3.0,4.0]=[4.0,6.0]
    

    第二个第二层张量和第三个第二层张量运算过程:

    [5.0,6.0]+[7.0,8.0]=[12.0,14.0]
    [[9.0,10.0]+[11.0,12.0]]=[20.0,22.0]
    

    拼合起来结果就是:

    [[ 4.  6.]
     [12. 14.]
     [20. 22.]]
    

    当axis=2时,也就是

    with tf.Session() as sess:
    	print(sess.run(tf.reduce_sum(x,axis=2)))
    

    结果是什么呢?

    [[ 3.  7.]
     [11. 15.]
     [19. 23.]]
    

    其中,3.0=1.0+2.0

    tf.nn.softmax_cross_entropy_with_logits

    一般我们在用softmax做最后一层,计算loss时常常用到该函数,函数签名:

    tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)
    
    • logits:神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes]。单样本的话,大小就是num_classes
    • labels:标签,大小要与logits保持一致

    计算过程分为2步:

    • 对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,output_shape: [None, num_classes]

    • 对每个样本,使用神经网络的输出和真实标签做交叉熵,交叉熵公式如下:

      [H_{y'}(y)=-sum_iy_i'log(y_i) ]

      对单个样本而言,(y_i')是真实标签第i维的值,(y_i)是神经网络输出的向量的第i维的值。可以看到,(y_i')(y_i)越一致,交叉熵越小,所以可以使用交叉熵作为loss。交叉熵可以参考其它文章

      该函数返回向量,要求标量交叉熵,可以使用tf.reduce_sum将其变为标量

  • 相关阅读:
    devexpress toolbar 填充整行宽度
    2. Rust的三板斧 安全,迅速,并发
    1. rust的优点
    谈谈我对sku的理解(3)----页面效果
    谈谈我对sku的理解(2)----数据库设计
    谈谈我对sku的理解(1)
    我眼里的奇酷手机360OS
    Oracle中的wm_concat()函数
    获取java本地系统信息 Properties
    java 获取用户的ip都是 127.0.0.1
  • 原文地址:https://www.cnblogs.com/mengnan/p/9330513.html
Copyright © 2011-2022 走看看