zoukankan      html  css  js  c++  java
  • pytorch 中改变tensor维度(transpose)、拼接(cat)、压缩(squeeze)详解

    具体示例如下,注意观察维度的变化

    1.改变tensor维度的操作:transpose、view、permute、t()、expand、repeat

    #coding=utf-8
    import  torch
    
    def change_tensor_shape():
        x=torch.randn(2,4,3)
        s=x.transpose(1,2) #shape=[2,3,4]
        y=x.view(2,3,4) #shape=[2,3,4]
        z=x.permute(0,2,1) #shape=[2,3,4]
    
        #tensor.t()只能转化 a 2D tensor
        m=torch.randn(2,3)#shape=[2,3]
        n=m.t()#shape=[3,2]
        print(m)
        print(n)
    
        #返回当前张量在某个维度为1扩展为更大的张量
        x = torch.Tensor([[1], [2], [3]])#shape=[3,1]
        t=x.expand(3, 4)
        print(t)
        '''
        tensor([[1., 1., 1., 1.],
            [2., 2., 2., 2.],
            [3., 3., 3., 3.]])
        '''
    
        #沿着特定的维度重复这个张量
        x=torch.Tensor([[1,2,3]])
        t=x.repeat(3, 2)
        print(t)
        '''
        tensor([[1., 2., 3., 1., 2., 3.],
            [1., 2., 3., 1., 2., 3.],
            [1., 2., 3., 1., 2., 3.]])
        '''
        x = torch.randn(2, 3, 4)
        t=x.repeat(2, 1, 3) #shape=[4, 3, 12]
    
    if __name__=='__main__':
        change_tensor_shape()

    2.tensor的拼接:cat、stack

    除了要拼接的维度可以不相等,其他维度必须相等

    #coding=utf-8
    import  torch
    
    
    def cat_and_stack():
    
        x = torch.randn(2,3,6)
        y = torch.randn(2,4,6)
        c=torch.cat((x,y),1)
        #c=(2*7*6)
        print(c.size)
    
        """
        而stack则会增加新的维度。
        如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。
        """
        a = torch.rand((1, 2))
        b = torch.rand((1, 2))
        c = torch.stack((a, b), 0)
        print(c.size())
    
    if __name__=='__main__':
        cat_and_stack()

     3.压缩和扩展维度:改变tensor中只有1个维度的tensor

     torch.squeeze(input, dim=None, out=None) → Tensor

    除去输入张量input中数值为1的维度,并返回新的张量。如果输入张量的形状为(A×1×B×C×1×D) 那么输出张量的形状为(A×B×C×D)

    当通过dim参数指定维度时,维度压缩操作只会在指定的维度上进行。如果输入向量的形状为(A×1×B),
    squeeze(input, 0)会保持张量的维度不变,只有在执行squeeze(input, 1)时,输入张量的形状会被压缩至(A×B) 。

    如果一个张量只有1个维度,那么它不会受到上述方法的影响。

    #coding=utf-8
    import  torch
    
    
    def squeeze_tensor():
        x = torch.Tensor(1,3)
        y=torch.squeeze(x, 0)
        print("y:",y)
        y=torch.unsqueeze(y, 1)
        print("y:",y)
    
    if __name__=='__main__':
        squeeze_tensor()
  • 相关阅读:
    【Python入门自学笔记专辑】——面向对象编程-实例方法11.3.6
    最简单的轮播广告(原生JS)
    (转)JavaScript一:为什么学习JavaScript?
    JAVASCRIPT中经典面试题
    使用AngularJS实现简单:全选和取消全选功能
    canvas绘制经典折线图(一)
    前端总结
    PHP如何连接MySQL数据库
    PHP预定义变量
    PHP语法
  • 原文地址:https://www.cnblogs.com/AntonioSu/p/12021384.html
Copyright © 2011-2022 走看看