zoukankan      html  css  js  c++  java
  • pytorch中的cat、stack、tranpose、permute、unsqeeze

    Cat

    对数据沿着某一维度进行拼接。cat后数据的总维数不变.

    比如下面代码对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3还是2维的tensor。

    import torch

    torch.manual_seed(1)

    x = torch.randn(2,3)

    y = torch.randn(1,3)

    print(x,y)

    结果:

    0.6614 0.2669 0.0617

    0.6213 -0.4519 -0.1661

    [torch.FloatTensor of size 2x3]

    -1.5228 0.3817 -1.0276

    [torch.FloatTensor of size 1x3]

    将两个tensor拼在一起:

    torch.cat((x,y),0)

        结果:

        0.6614 0.2669 0.0617

        0.6213 -0.4519 -0.1661

        -1.5228 0.3817 -1.0276

        [torch.FloatTensor of size 3x3]

        stack,增加新的维度进行堆叠

        而stack则会增加新的维度。
        如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。
        见代码:

        a=torch.rand((1,2))
        b=torch.rand((1,2))

        c=torch.stack((a,b),0)

        c.size()


        结果:

        torch.Size([2, 1, 2])

        换成维度1:

        d=torch.stack((a,b),1)

        d.size()

        结果:

        torch.Size([1, 2, 2])

        transpose ,交换维度

        代码:

        torch.manual_seed(1)

        x = torch.randn(2,3)

        print(x)

        结果:

        0.6614 0.2669 0.0617

        0.6213 -0.4519 -0.1661

        [torch.FloatTensor of size 2x3]

        将x的维度互换:

        x.transpose(0,1)

        结果:
        0.6614 0.6213 
        0.2669 -0.4519 
        0.0617 -0.1661
         [torch.FloatTensor of size 3x2]

        permute,适合多维数据,更灵活的transpose

        permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。
        代码如下:

        x = torch.randn(2,3,4)

        print(x.size())

        x_p = x.permute(1,0,2) # 将原来第1维变为0维,同理,0→1,2→2 print(x_p.size())

        结果:

        torch.Size([2, 3, 4])

        torch.Size([3, 2, 4])

        squeeze 和 unsqueeze

        squeeze(dim_n)压缩,即去掉元素数量为1的dim_n维度。同理unsqueeze(dim_n),增加dim_n维度,元素数量为1。

        上代码:

        # 定义张量
        import torch
        
        b = torch.Tensor(2,1)
        b.shape
        Out[28]: torch.Size([2, 1])
        
        # 不加参数,去掉所有为元素个数为1的维度
        b_ = b.squeeze()
        b_.shape
        Out[30]: torch.Size([2])
        
        # 加上参数,去掉第一维的元素为1,不起作用,因为第一维有2个元素
        b_ = b.squeeze(0)
        b_.shape 
        Out[32]: torch.Size([2, 1])
        
        # 这样就可以了
        b_ = b.squeeze(1)
        b_.shape
        Out[34]: torch.Size([2])
        
        # 增加一个维度
        b_ = b.unsqueeze(2)
        b_.shape
        Out[36]: torch.Size([2, 1, 1])








          • 相关阅读:
            SpringBoot入门
            Java自定义注解(1)
            git集成idea
            git常用命令
            Shiro授权
            shiro认证
            shiro入门
            SpringMVC文件上传
            SpringMVC入门
            mybatis关联关系映射
          • 原文地址:https://www.cnblogs.com/yifdu25/p/9399047.html
          Copyright © 2011-2022 走看看