zoukankan      html  css  js  c++  java
  • Tensorflow中的tf.argmax()函数

    转载请注明出处:http://www.cnblogs.com/willnote/p/6758953.html

    官方API定义

    tf.argmax(input, axis=None, name=None, dimension=None)

    Returns the index with the largest value across axes of a tensor.
    Args:

    • input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128, qint8, quint8, qint32, half.
    • axis: A Tensor. Must be one of the following types: int32, int64. int32, 0 <= axis < rank(input). Describes which axis of the input Tensor to reduce across. For vectors, use axis = 0.
    • name: A name for the operation (optional).

    Returns:

    • A Tensor of type int64.

    关于axis

    定义中的axis与numpy中的axis是一致的,下面通过代码进行解释

    import numpy as np
    import tensorflow as tf
    
    sess = tf.session()
    m = sess.run(tf.truncated_normal((5,10), stddev = 0.1) )
    print type(m)
    print m
    
    -------------------------------------------------------------------------------
    <type 'numpy.ndarray'>
    [[ 0.09957541 -0.0965599   0.06064715 -0.03011306  0.05533558  0.17263047
      -0.02660419  0.08313394 -0.07225946  0.04916157]
     [ 0.11304571  0.02099175  0.03591062  0.01287777 -0.11302195  0.04822164
      -0.06853487  0.0800944  -0.1155676  -0.01168544]
     [ 0.15760773  0.05613248  0.04839646 -0.0218203   0.02233066  0.00929849
      -0.0942843  -0.05943     0.08726917 -0.059653  ]
     [ 0.02553608  0.07298559 -0.06958302  0.02948747  0.00232073  0.11875584
      -0.08325859 -0.06616175  0.15124641  0.09522969]
     [-0.04616683  0.01816062 -0.10866459 -0.12478453  0.01195056  0.0580056
      -0.08500613  0.00635608 -0.00108647  0.12054099]]
    

    m是一个5行10列的矩阵,类型为numpy.ndarray

    #使用tensorflow中的tf.argmax()
    col_max = sess.run(tf.argmax(m, 0) )  #当axis=0时返回每一列的最大值的位置索引
    print col_max
    
    row_max = sess.run(tf.argmax(m, 1) )  #当axis=1时返回每一行中的最大值的位置索引
    print row_max
    
    array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
    array([5, 0, 0, 8, 9])
    
    -------------------------------------------------------------------------------
    #使用numpy中的numpy.argmax
    row_max = m.argmax(0)
    print row_max
    
    col_max = m.argmax(1)
    print col_max
    
    array([2, 3, 0, 3, 0, 0, 0, 0, 3, 4])
    array([5, 0, 0, 8, 9])
    

    可以看到tf.argmax()与numpy.argmax()方法的用法是一致的

    • axis = 0的时候返回每一列最大值的位置索引
    • axis = 1的时候返回每一行最大值的位置索引
    • axis = 2、3、4...,即为多维张量时,同理推断

    参考

    1. Tensorflow官方API tf.argmax说明
    2. Numpy官方AIP numpy.argmax说明
  • 相关阅读:
    HDU 3572 Task Schedule(拆点+最大流dinic)
    POJ 1236 Network of Schools(Tarjan缩点)
    HDU 3605 Escape(状压+最大流)
    HDU 1166 敌兵布阵(分块)
    Leetcode 223 Rectangle Area
    Leetcode 219 Contains Duplicate II STL
    Leetcode 36 Valid Sudoku
    Leetcode 88 Merge Sorted Array STL
    Leetcode 160 Intersection of Two Linked Lists 单向链表
    Leetcode 111 Minimum Depth of Binary Tree 二叉树
  • 原文地址:https://www.cnblogs.com/willnote/p/6758953.html
Copyright © 2011-2022 走看看