zoukankan      html  css  js  c++  java
  • pytorch 花式张量(Tensor)操作

     一、张量的维度操作

    1.squezee & unsqueeze

    x = torch.rand(5,1,2,1)
    x = torch.squeeze(x)#x.squeeze()去掉大小为1的维度,x.shape =(5,2)
    x = torch.unsqueeze(x,2)#x.unsqueeze(2)和squeeze相反在第三维上扩展,x.shape = (5,2,1)

    2.张量扩散,在指定维度上将原来的张量扩展到指定大小,比如原来x是31,输入size为[3, 4],可以将其扩大成34,4为原来1个元素的复制

    x = x.expand(*size)
    x = x.expand_as(y)# y:[3,4

    3.转置,torch.transpose 只能交换两个维度 permute没有限制

    x = torch.transpose(x, 1, 2) # 交换1和2维度
    x = x.permute(1, 2, 3, 0) # 进行维度重组

    4.改变形状,view&reshape 两者作用一样,区别在于是当从多的维度变到少的维度时,如果张量不是在连续内存存放,则view无法变成合并维度,会报错

    x = x.view(1, 2, -1)#把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其他维度的tensor
    x = x.reshape(1, 2, -1)

    5.张量拼接 cat & stack

    torch.cat(a_tuple, dim)#tuple 是一个张量或者元组,在指定维度上进行拼接
    torch.stack(a_tuple, dim)#与cat不同的在于,cat只能在原有的某一维度上进行连接,stack可以创建一个新的维度,将原有维度在这个维度上进行顺序排列
    #比如说,有2个4x4的张量,用cat就只能把它们变成一个8x4或4x8的张量,用stack可以变成2x4x4.

    6.张量拆分,chunk & split

    torch.chunk(a, chunk_num, dim)#在指定维度上将a变成chunk_num个大小相等的chunk,返回一个tuple。如果最后一个不够chunk_num,就返回剩下的
    torch.split(a, chunk_size, dim)#与chunk相似,只是第二次参数变成了chunk_size

    一、张量的乘法操作

    1 * 点乘 & torch.mul 两者用法相同,后者用了broadcast概念

    #标量k做*乘法的结果是Tensor的每个元素乘以k(相当于把k复制成与lhs大小相同,元素全为k的Tensor
    a = torch.ones(3,4)
    a = a * 2
    ''' tensor([[2., 2., 2., 2.],
            [2., 2., 2., 2.],
            [2., 2., 2., 2.]])'''
    #与行向量相乘向量作乘法 每列乘以行向量对应列的值(相当于把行向量的行复制,A的列数和向量数目相同),与列向量同理
    b = torch.Tensor([1,2,3,4])
    a*b
    '''
    tensor([[1., 2., 3., 4.],
         [1., 2., 3., 4.],
         [1., 2., 3., 4.]])
    '''
    # 向量*向量,element-wise product

    2.torch.mm&torch.matmul两者用法相同,后者用了broadcast概念

    torch.matmul(input, other, out=None) → Tensor
    #两个张量的矩阵乘积。行为取决于张量的维数,如下所示:
    #1. 如果两个张量都是一维的,则返回点积(标量)。
    # vector x vector
     tensor1 = torch.randn(3)
     tensor2 = torch.randn(3)
     torch.matmul(tensor1, tensor2).size()
    torch.Size([])
    #2. 如果两个参数都是二维的,则返回矩阵矩阵乘积。
    # matrix x matrix
     tensor1 = torch.randn(3, 4)
     tensor2 = torch.randn(4, 5)
     torch.matmul(tensor1, tensor2).size() #torch.Size([3, 5])
    #3. 如果第一个参数是一维的,而第二个参数是二维的,则为了矩阵乘法,会将1附加到其维数上。矩阵相乘后,将删除前置尺寸。
    # 也就是让tensor2变成矩阵表示,1x3的矩阵和 3x4的矩阵,得到1x4的矩阵,然后删除1
     tensor1 = torch.randn(3, 4)
     tensor2 = torch.randn(3)
     torch.matmul(tensor2, tensor1).size() #torch.Size([4])
    #4. 如果第一个参数为二维,第二个参数为一维,则返回矩阵向量乘积。
    # matrix x vector
     tensor1 = torch.randn(3, 4)
     tensor2 = torch.randn(4)
     torch.matmul(tensor1, tensor2).size()#torch.Size([3])
    #5. 如果两个自变量至少为一维且至少一个自变量为N维(其中N> 2),则返回批处理矩阵乘法。
    #如果第一个参数是一维的,则在其维数之前添加一个1,以实现批量矩阵乘法并在其后删除。
    #如果第二个参数为一维,则将1附加到其维上,以实现成批矩阵倍数的目的,然后将其删除。
    #非矩阵(即批量)维度可以被广播(因此必须是可广播的)。
    #例如,如果input为(jx1xnxm)张量,而other为(k×m×p)张量,out将是(j×k×n×p)张量。最后两维必须,满足矩阵乘法
     # batched matrix x broadcasted vector
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(4)
     torch.matmul(tensor1, tensor2).size()#torch.Size([10, 3])
     # batched matrix x batched matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(10, 4, 5)
     torch.matmul(tensor1, tensor2).size()#torch.Size([10, 3, 5])
     # batched matrix x broadcasted matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(4, 5)
     torch.matmul(tensor1, tensor2).size()#torch.Size([10, 3, 5])
     tensor1 = torch.randn(10, 1, 3, 4)
     tensor2 = torch.randn(2, 4, 5)
     torch.matmul(tensor1, tensor2).size()#torch.Size([10, 2, 3, 5])

    3.通用乘法:torch.tensordot

    #可以表示任意多维,任意组合形式的矩阵相乘
    # 如果 a = torch.Tensor([1, 2, 3, 4]), b = torch.tensor([2, 3, 4, 5])
    # 想表示内积,直接令 dims=1 即可
    # 如果dimss=0则按照逐元素挨个相乘累加
    # dimss可以为二维数组,(dims_a, dims_b),指定两个张量任意维度相乘
    c = torch.tensordot(a, b, dims)
    # a: B N F  b: P F
    c = torch.tensordot(a,b,dims=([-1],[-1])) # c: B N P

    4.einsum

    #使用爱因斯坦求和约定来计算多线性表达式(即乘积和)的方法,能够以一种统一的方式表示各种各样的张量运算(内积、外积、转置、点乘、矩阵的迹、其他自定义运算)。
    #a: i k , b: j k
    c = torch.enisum('ik, jk -> ij', a,b) # c : i j 及为下面的公式

    其他可参考https://blog.csdn.net/a2806005024/article/details/96462827

  • 相关阅读:
    lecture 11.4
    lecture 10.30
    boolean functions and beyon
    lecture10.21
    golang hex to string
    golang中 将string转化为json
    ubuntu16报错: add-apt-repository command not found
    ubuntu16的防火墙关闭
    ubuntu 16 解决错误 manpath: can't set the locale; make sure $LC_* and $LANG are correct
    go get 安装时报 (https fetch: Get https://golang.org/x/crypto/acme/autocert?go-get=1: dial tcp 220.255.2.153:443: i/o timeout)
  • 原文地址:https://www.cnblogs.com/yutingmoran/p/11882816.html
Copyright © 2011-2022 走看看