zoukankan      html  css  js  c++  java
  • TensorFlow、Numpy中的axis的理解

    TensorFlow中有很多函数涉及到axis,比如tf.reduce_mean(),其函数原型如下:

    1 def reduce_mean(input_tensor,
    2                 axis=None,
    3                 keepdims=None,
    4                 name=None,
    5                 reduction_indices=None,
    6                 keep_dims=None):

    其中axis表示的是,对该维度进行求均值(默认情况下,是对所有值求均值)。
    除了TensorFlow中,numpy中也经常遇到很多对矩阵操作的函数会涉及axis操作。比如np.mean(),其函数原型如下:

    1 def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue):

    想要弄清楚如何处理涉及axis(维度)的操作,必须先明白axis是什么。
    首先axis是维度,如果axis=0则对应着高; 如果axis=1则对应着行处理;如果axis=2则对应着列;如果axis=3…n(无法用直观的图来表示)。我相信很多人看到这还是会一头雾水。什么是高,行还有列。为了说明这个问题,我举个列子:

    data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]
    data_np=np.array(data)
    print(data_np)
    [[[  1   2   3]
      [ 11  22  33]]
    
     [[  4   5   6]
      [ 44  55  66]]
    
     [[ 10  11  12]
      [100 110 120]]
    
     [[  7   8   9]
      [ 77  88  99]]]
      
    如上面,可以将最外层[ ]去掉,可以发现有4组元素(这里的元素是矩阵),你可以将其理解为高。
    再从这3组元素中选取一组,比如选择的是
    [[  1   2   3]
      [ 11  22  33]]
    然后将该组的最外层[ ]去掉,可以发现有2组元素分别为[  1   2   3]和 [ 11  22  33],此时对应的是行。
    在从这两组元素中选组一组,比如选择的是
     [ 11  22  33]
     现在无需去掉最外层的[ ]了,一眼就能看出里面有3个元素。这就是对应的列。
     理解了上面的分析后,很容易就知道(高,行,列)对应的其实就是改矩阵的shape.
    print(data_np.shape):
    (4,2,3)

    现在弄清楚了axis的值与(高,行,列)的关系后,再来分析tf.reduce_mean()或者np.mean()等函数是如何对axis进行操作的。

     1 data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]
     2 
     3 data_tensor=tf.constant(data,dtype=tf.float32)
     4 
     5 mean_axis0=tf.reduce_mean(data_tensor,axis=0)
     6 mean_axis1=tf.reduce_mean(data_tensor,axis=1)
     7 mean_axis2=tf.reduce_mean(data_tensor,axis=2)
     8 
     9 with tf.Session() as sess:
    10     print(sess.run(mean_axis0))
    11     print(sess.run(mean_axis1))
    12     print(sess.run(mean_axis2))

    针对上述代码,我们先对axis=0维度的数据处理进行分析。
    首先对上述data数据进行立体化变换,如下图(本人本想用软件来绘制3D的矩阵叠加效果,可惜找了很多软件都不适合,也许是本人寻找的还不够,欢迎有知道可以绘制3D的矩阵叠加效果的朋友们,能够分享一下。感激…)

    如上如,axis=0的维度数据求均值,

    [[(1+4+10+7)/4         (2+5+11+8)/4       (3+6+12+9)/4]
    [(11+44+100+77)/4      (22+55+110+88)/4   (33+66+120+99)/4]]
    =
    [[ 5.5   6.5   7.5 ]
     [58.   68.75 79.5 ]]

    同理,对axis=1的维度数据求均值,

    [[(1+11)/2    (2+22)/2    (3+33)/2]
     [(4+44)/2    (5+55)/2    (6+66)/2]
     [(10+100)/2  (11+110)/2  (12+120)/2]
     [(7+77)/2    (8+88)/2    (9+99)/2]]
     =
     [[ 6.  12.  18. ]
     [24.  30.  36. ]
     [55.  60.5 66. ]
     [42.  48.  54. ]]

    同理可得axis=2维度的数据平均值为(过程留给读者去推,运算结果如下):

    [[  2.  22.]
     [  5.  55.]
     [ 11. 110.]
     [  8.  88.]]

    在python的世界里,有很多时候都需要对数据进行维度的操作,如果对axis理解的不透的话,很容易找不着方向。

    更多干货请关注:

  • 相关阅读:
    Two Sum II
    Subarray Sum
    Intersection of Two Arrays
    Reorder List
    Convert Sorted List to Binary Search Tree
    Remove Duplicates from Sorted List II
    Partition List
    Linked List Cycle II
    Sort List
    struts2结果跳转和参数获取
  • 原文地址:https://www.cnblogs.com/RoseVorchid/p/10633299.html
Copyright © 2011-2022 走看看