zoukankan      html  css  js  c++  java
  • pytorch 中tensor的加减和mul、matmul、bmm

    如下是tensor乘法与加减法,对应位相乘或相加减,可以一对多

    import torch
    def add_and_mul():
        x = torch.Tensor([[[1, 2, 3],
                           [4, 5, 6]],
    
                          [[7, 8, 9],
                           [10, 11, 12]]])
        y = torch.Tensor([1, 2, 3])
        y = y - x
        print(y)
        '''
        tensor([[[ 0.,  0.,  0.],
             [-3., -3., -3.]],
    
            [[-6., -6., -6.],
             [-9., -9., -9.]]])
        '''
        t = 1. - x.sum(dim=1)
        print(t)
        '''
        tensor([[ -4.,  -6.,  -8.],
            [-16., -18., -20.]])
        '''
        y = torch.Tensor([[1, 2, 3],
                          [4, 5, 6]])
        y = torch.mul(y,x) #等价于此方法 y*x
        print(y)
        '''
        tensor([[[ 1.,  4.,  9.],
             [16., 25., 36.]],
    
            [[ 7., 16., 27.],
             [40., 55., 72.]]])
        '''
        z = x ** 2
        print(z)
        """
        tensor([[[  1.,   4.,   9.],
             [ 16.,  25.,  36.]],
    
            [[ 49.,  64.,  81.],
             [100., 121., 144.]]])
        """
    
    if __name__=='__main__':
        add_and_mul()

    矩阵的乘法,matmul和bmm的具体代码

    import torch
    
    def matmul_and_bmm():
        # a=(2*3*4)
        a = torch.Tensor([[[1, 2, 3, 4],
                           [4, 0, 6, 0],
                           [3, 2, 1, 4]],
                          [[3, 2, 1, 0],
                           [0, 3, 2, 2],
                           [1, 2, 1, 0]]])
        # b=(2,2,4)
        b = torch.Tensor([[[1, 2, 3, 4],
                           [4, 0, 6, 0]],
                          [[3, 2, 1, 0],
                           [1, 2, 1, 0]]])
    
        b=b.transpose(1, 2)
        # res=(2,3,2),对于a*b,是第一维度不变,而后[3,4] x [4,2]=[3,2]
        #res[0,:]=a[0,:] x b[0,;];   res[1,:]=a[1,:] x b[1,;] 其中x表示矩阵乘法
        res = torch.matmul(a, b)  # 维度res=[2,3,2]
        res2 = torch.bmm(a, b)  # 维度res2=[2,3,2]
        print(res)  # res2的值等于res
        """
        tensor([[[30., 22.],
                 [22., 52.],
                 [26., 18.]],
    
                [[14.,  8.],
                 [ 8.,  8.],
                 [ 8.,  6.]]])
        """
    
    if __name__=='__main__':
        matmul_and_bmm()
  • 相关阅读:
    实验一框架选择及其分析
    站立会议(一)
    关于有多少个1的计算
    寻找水王问题
    如何买到更便宜的书
    NABCD
    二维数组首尾相连求最大子矩阵
    环数组求最大子数组的和
    二维数组求最大矩阵
    关于铁道大学基础教学楼电梯调查
  • 原文地址:https://www.cnblogs.com/AntonioSu/p/12021366.html
Copyright © 2011-2022 走看看