zoukankan      html  css  js  c++  java
  • 笔记 EINSUM IS ALL YOU NEED

    原文 https://rockt.github.io/2018/04/30/einsum

    就是说有一种运算,叫做einsum,可以做各种矩阵和向量的运算,而且特别简洁和优美

    自己跑一下里面的例子,就知道是怎么回事了,

    这里记录一下其中的tensor contraction,算是最general的形式了

    先看 torch.einsum('ij,ij->', [a, b]) 是什么意思?

    import torch
    
    a = torch.arange(2*3).reshape(2, 3)
    b = torch.arange(2*3).reshape(2, 3)
    x = torch.einsum('ij,ij->', [a, b])
    print(a)
    print(b)
    print(x)
    
    res = 0
    for i in range(2):
        for j in range(3):
                res += a[i,j] * b[i,j]
    
    print(res)

    结果:

    (deeplearning) ➜  Catchfish python einsum_test.py
    tensor([[0, 1, 2],
            [3, 4, 5]])
    tensor([[0, 1, 2],
            [3, 4, 5]])
    tensor(55)
    tensor(55)

    相当于把对应位置相乘再相加,这样二维空间收缩为1个值

    三维矩阵的收缩同理,torch.einsum('ijk,ijk->', [a, b]) 是什么意思?

    其实二维矩阵的乘法也是tensor contraction,只不过只是将其中一维收缩,torch.einsum('ik,kj->ij', [a, b])

    能收缩的条件是:只要对应维的长度相同即可

    前面的讲完了,重点是高维矩阵是如何收缩的?

    例子:

    内部是怎么运算的呢?相同维数的3和5进行了收缩,相当于2,7,11,13,17固定

    验证一下:取出一个固定状态,将相同的那两维收缩,与之前整体收缩再取同一状态对比,发现两个值一样

    import torch
    
    a = torch.arange(2*3*5*7).reshape(2,3,5,7)
    b = torch.arange(11*13*3*17*5).reshape(11,13,3,17,5)
    x = torch.einsum('pqrs,tuqvr->pstuv', [a, b])
    print(x.shape)
    
    m1 = a[1, :, :, 5]
    m2 = b[6, 7, :, 8, :]
    res = torch.einsum("ij,ij->", [m1, m2])
    print(res)
    print(x[1, 5, 6, 7, 8])

    结果:

    (deeplearning) ➜  Catchfish python tensor_contraction.py 
    torch.Size([2, 7, 11, 13, 17])
    tensor(52027730)
    tensor(52027730)

    维度的计算:相同维数的收缩了,剩下的各个维数组成结果的维数

    自己可以试一下,收缩三个及更高的维数也是一样的做法。

    个性签名:时间会解决一切
  • 相关阅读:
    ThinkPHP安全规范指引
    正则表达式:不能包含某些特殊字符串
    头晕是因为虚 ( ̄^ ̄゜)
    vs code中文扩展包
    table-cell width:1% 深入理解
    C#使用NPOI操作Excel
    C#利用LumenWorks.Framework.IO.Csv读取CSV文件
    发送邮件代码
    .net 集合详解
    EF Code First:数据更新最佳实践
  • 原文地址:https://www.cnblogs.com/lfri/p/15473640.html
Copyright © 2011-2022 走看看