zoukankan      html  css  js  c++  java
  • numpy.trace对于三维以上array的解析

    numpy.trace是求shape的对角线上的元素的和,具体看 https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.trace.html

    或者搜索 numpy.trace, 二维的比较好理解,对于三维以上的对角线(三维的对角线不止2条,该选哪两条呢)就不好理解了,以下是本人的理解

    # 3-D array 的trace算法 
    
    import numpy as np
    
    a = np.arange(8).reshape((2,2,2))
    print 'a =',a
    
    print np.trace(a)
    
    x0 = a[0,0,0] + a[1,1,0]#即 0 + 6
    print 'a[1,1,0] =', a[1,1,0],';','x0 =',x0
    x1 = a[0,0,1] + a[1,1,1]#即 1 + 7
    print 'a[1,1,1] =', a[1,1,1],';','x1 =',x1
    
    #当然你可能认为6 = 2 + 4;或者 8 = 5 + 3,确实一个立方体应该是4条对角线,但是是不是都可以呢
    #请看下面的3-D
    b = np.array([ [ [100, 198],
                     [2, 3]],
                   [ [4, 5],
                    [6,7]]])
    print 'b =',b
    print np.trace(b)
    y0 = b[0,0,0] + b[1,1,0]#即 0 + 6
    print 'y0 = b[0,0,0] + b[1,1,0] = %d + %d = %d' % (b[0,0,0], b[1,1,0], y0)
    y1 = b[0,0,1] + b[1,1,1]#即 1 + 7
    print 'y1 = b[0,0,1] + b[1,1,1] = %d + %d = %d' % (b[0,0,1], b[1,1,1], y1)
    
    #事实证明只能是其中的固定方向的
    #之所以能以此种方式思考是因为2-D的直观
    
    c = np.array([[2, 8],
                 [4,5]])
    print 'c =', c
    print 'trace =', np.trace(c)
    trace_c = c[0,0] + c[1,1]
    print 'c[0,0] + c[1,1] = %d + %d = %d' % (c[0,0], c[1,1], trace_c)
    
    #来看看4-D的
    d = np.arange(32).reshape((2,2,2,4))
    print 'd =', d
    #猜猜看这个trace结果是什么shape,
    #(2, 4),只要去掉前面 2个维度即可
    print 'np.trace(d).shape =', np.trace(d).shape
    print 'np.trace(d) =', np.trace(d)
    td00 = d[0,0,0,0] + d[1,1,0,0]
    print 'td00 = ' + 'd[0,0,0,0] + d[1,1,0,0] = %d + %d = %d' % (d[0,0,0,0], d[1,1,0,0], td00)
    td01 = d[0,0,0,1] + d[1,1,0,1]
    print 'td01 = ' + 'd[0,0,0,1] + d[1,1,0,1] = %d + %d = %d' % (d[0,0,0,1], d[1,1,0,1], td01)
    td02 = d[0,0,0,2] + d[1,1,0,2]
    print 'td02 = ' + 'd[0,0,0,2] + d[1,1,0,2] = %d + %d = %d' % (d[0,0,0,2], d[1,1,0,2], td02)
    td03 = d[0,0,0,3] + d[1,1,0,3]
    print 'td03 = ' + 'd[0,0,0,3] + d[1,1,0,3] = %d + %d = %d' % (d[0,0,0,3], d[1,1,0,3], td03)
    
    print
    td10 = d[0,0,1,0] + d[1,1,1,0]
    print 'td10 = ' + 'd[0,0,1,0] + d[1,1,1,0] = %d + %d = %d' % (d[0,0,1,0], d[1,1,1,0], td10)
    td11 = d[0,0,1,1] + d[1,1,1,1]
    print 'td11 = ' + 'd[0,0,1,1] + d[1,1,1,1] = %d + %d = %d' % (d[0,0,1,1], d[1,1,1,1], td11)
    td12 = d[0,0,1,2] + d[1,1,1,2]
    print 'td12 = ' + 'd[0,0,1,2] + d[1,1,1,2] = %d + %d = %d' % (d[0,0,1,2], d[1,1,1,2], td12)
    td13 = d[0,0,1,3] + d[1,1,1,3]
    print 'td13 = ' + 'd[0,0,1,3] + d[1,1,1,3] = %d + %d = %d' % (d[0,0,1,3], d[1,1,1,3], td13)

    以下是运行结果:(python 2.7, numpy:1.14.2:

    a = [[[0 1]
      [2 3]]
    
     [[4 5]
      [6 7]]]
    [6 8]
    a[1,1,0] = 6 ; x0 = 6
    a[1,1,1] = 7 ; x1 = 8
    b = [[[100 198]
      [  2   3]]
    
     [[  4   5]
      [  6   7]]]
    [106 205]
    y0 = b[0,0,0] + b[1,1,0] = 100 + 6 = 106
    y1 = b[0,0,1] + b[1,1,1] = 198 + 7 = 205
    c = [[2 8]
     [4 5]]
    trace = 7
    c[0,0] + c[1,1] = 2 + 5 = 7
    d = [[[[ 0  1  2  3]
       [ 4  5  6  7]]
    
      [[ 8  9 10 11]
       [12 13 14 15]]]
    
    
     [[[16 17 18 19]
       [20 21 22 23]]
    
      [[24 25 26 27]
       [28 29 30 31]]]]
    np.trace(d).shape = (2, 4)
    np.trace(d) = [[24 26 28 30]
     [32 34 36 38]]
    td00 = d[0,0,0,0] + d[1,1,0,0] = 0 + 24 = 24
    td01 = d[0,0,0,1] + d[1,1,0,1] = 1 + 25 = 26
    td02 = d[0,0,0,2] + d[1,1,0,2] = 2 + 26 = 28
    td03 = d[0,0,0,3] + d[1,1,0,3] = 3 + 27 = 30
    
    td10 = d[0,0,1,0] + d[1,1,1,0] = 4 + 28 = 32
    td11 = d[0,0,1,1] + d[1,1,1,1] = 5 + 29 = 34
    td12 = d[0,0,1,2] + d[1,1,1,2] = 6 + 30 = 36
    td13 = d[0,0,1,3] + d[1,1,1,3] = 7 + 31 = 38
  • 相关阅读:
    03_已解决 [salt.master :2195][ERROR ][6219] Failed to allocate a jid. The requested returner 'mysql' could not be loaded.
    02_已解决 [salt.minion :1758][ERROR ][52886] Returner mysql.returner could not be loaded: 'mysql' __virtual__ returned False: Could not import mysql returner; mysql python client is not installed.
    05_centos7安装python3
    04_mysql安装
    03_mysql-python模块, linux环境下python2,python3的
    02_pip区别: linux环境下python2,python3的
    01 salt平台,软件架构图
    01_初识redis
    list_for_each_entry()函数分析
    趣解什么是网关
  • 原文地址:https://www.cnblogs.com/mengshu-lbq/p/8615873.html
Copyright © 2011-2022 走看看