zoukankan      html  css  js  c++  java
  • einsum详解

    参考 http://www.elecfans.com/d/779631.html

    https://blog.csdn.net/ashome123/article/details/117110042

    import torch
    from torch import einsum
    
    a = torch.arange(3)  # [0, 1, 2]
    b = torch.arange(3, 6)  # [3, 4, 5]
    
    c = torch.arange(0, 6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
    d = torch.arange(0, 6).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
    
    # 转置
    # print(einsum('ij->ji', c))  # tensor([0, 1, 2])
    
    # 求和
    # print(einsum('i', a))  # tensor([0, 1, 2])
    
    # print(einsum('ij->i', c))  # 按行tensor([ 3, 12])
    
    # print(einsum('ij->j', c))  # 按列tensor([ 3, 5, 7])
    
    # 矩阵与向量乘法
    # 参数加不加[ ] 好像无所谓
    # print(torch.einsum('ik,k->i', c, a))  # tensor([ 5, 14]) 按行乘以一行
    
    # print(torch.einsum('ik,i->k', [d, a]))  # tensor([ 10, 13]) 按列乘以一行
    
    # 矩阵与矩阵乘法
    # print(torch.einsum('ik,kj->ij', [c, d]))  # tensor([[10, 13],[28, 40]])
    
    # 点积
    # print(einsum('i,i->', [a, b])) # tensor(14)
    
    # print(torch.einsum('ij,ij->', [c, c])) # tensor(55)
    
    # print(torch.einsum('ij,ij-> ij', [c, c]))  # tensor([[ 0,  1,  4], [ 9, 16, 25]])
    
    # 外积
    # tensor([[ 0,  0,  0],
    #         [ 3,  4,  5],
    #         [ 6,  8, 10]])
    # print(torch.einsum('i,j->ij', [a, b])) # tensor(55)
    
    # batch矩阵相乘
    # x = torch.arange(12).reshape(2,2,3)
    # xx = torch.arange(12).reshape(2,3,2)
    # print(x)
    # print(xx)
    # print(torch.einsum('ijk,ikl->ijl', x, xx))
    # tensor([[[ 0,  1,  2],
    #          [ 3,  4,  5]],
    #         [[ 6,  7,  8],
    #          [ 9, 10, 11]]])
    #
    # tensor([[[ 0,  1],
    #          [ 2,  3],
    #          [ 4,  5]],
    #         [[ 6,  7],
    #          [ 8,  9],
    #          [10, 11]]])
    #
    # tensor([[[ 10,  13],
    #          [ 28,  40]],
    #         [[172, 193],
    #          [244, 274]]])
    
    # 张量缩约
    # aa = torch.randn(2,3,5,7)
    # bb = torch.randn(11,13,3,17,5)
    # print(torch.einsum('pqrs,tuqvr->pstuv', [aa, bb]).shape)
    # torch.Size([2, 7, 11, 13, 17])
    
    
    # 双线性变换
    # a = torch.randint(1,5, size=(2, 3))
    # b = torch.randint(1,5, size=(4, 3, 6))
    # c = torch.randint(1,5, size=(2, 6))
    # print(a)
    # print(b)
    # print(c)
    # print(torch.einsum('ik,jkl,il->ij', [a, b, c]))
    # tensor([[4, 4, 4],
    #         [2, 2, 4]])
    # tensor([[[3, 3, 2, 2, 2, 4],
    #          [1, 2, 4, 3, 1, 1],
    #          [4, 3, 3, 3, 1, 1]],
    #         [[1, 2, 2, 2, 2, 3],
    #          [3, 2, 3, 3, 1, 1],
    #          [1, 3, 3, 3, 3, 3]],
    #         [[1, 3, 4, 3, 1, 2],
    #          [4, 4, 4, 1, 3, 4],
    #          [4, 4, 3, 3, 1, 3]],
    #         [[2, 1, 1, 4, 1, 2],
    #          [4, 1, 1, 2, 4, 3],
    #          [2, 3, 4, 2, 3, 3]]])
    # tensor([[2, 2, 1, 3, 2, 4],
    #         [3, 2, 1, 2, 1, 2]])
    # tensor([[388, 384, 472, 416],
    #         [222, 200, 266, 218]])
    












    种一棵树最好的时间是十年前,其次是现在。
  • 相关阅读:
    stack.pop()和stack.peek()的区别
    信号与系统,系统函数的影响
    java中short、int、long、float、double取值范围
    Spring从容器获得组件的方法
    Eclipse中项目的类路径文件夹
    Math的常用方法
    spring基本入门步骤
    opencv入门
    make和cmake构建工具
    使用eclipse开发c++程序及开发环境搭建
  • 原文地址:https://www.cnblogs.com/islch/p/15466881.html
Copyright © 2011-2022 走看看