zoukankan      html  css  js  c++  java
  • Pytorch-tensor维度的扩展,挤压,扩张

    数据本身不发生改变,数据的访问方式发生了改变

    1.维度的扩展

    函数:unsqueeze()

    # a是一个4维的
        a = torch.randn(4, 3, 28, 28)
        print('a.shape
    ', a.shape)
    
        print('
    维度扩展(变成5维的):')
        print('第0维前加1维')
        print(a.unsqueeze(0).shape)
        print('第4维前加1维')
        print(a.unsqueeze(4).shape)
        print('在-1维前加1维')
        print(a.unsqueeze(-1).shape)
        print('在-4维前加1维')
        print(a.unsqueeze(-4).shape)
        print('在-5维前加1维')
        print(a.unsqueeze(-5).shape)
    

    输出结果

    a.shape
     torch.Size([4, 3, 28, 28])
    
    维度扩展(变成5维的):
    第0维前加1维
    torch.Size([1, 4, 3, 28, 28])
    第4维前加1维
    torch.Size([4, 3, 28, 28, 1])
    在-1维前加1维
    torch.Size([4, 3, 28, 28, 1])
    在-4维前加1维
    torch.Size([4, 1, 3, 28, 28])
    在-5维前加1维
    torch.Size([1, 4, 3, 28, 28])
    

    注意,第5维前加1维,就会出错

        # print(a.unsqueeze(5).shape)
        # Errot:Dimension out of range (expected to be in range of -5, 4], but got 5)
    

    连续扩维:unsqueeze()

        # b是一个1维的
        b = torch.tensor([1.2, 2.3])
        print('b.shape
    ', b.shape)
        print()
        # 0维之前插入1维,变成1,2]
        print(b.unsqueeze(0))
        print()
        # 1维之前插入1维,变成2,1]
        print(b.unsqueeze(1))
    
        # 连续扩维,然后再对某个维度进行扩张
        print(b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape)
    

    输出结果

    b.shape
     torch.Size([2])
    
    tensor([[1.2000, 2.3000]])
    
    tensor([[1.2000],
            [2.3000]])
    torch.Size([1, 2, 1, 1])
    

    2.挤压维度

    函数:squeeze()

        # 挤压维度,只会挤压shape为1的维度,如果shape不是1的话,当前值就不会变
        c = torch.randn(1, 32, 1, 2)
        print(c.shape)
        print(c.squeeze(0).shape)
        print(c.squeeze(1).shape)  # shape不是1,不会变
        print(c.squeeze(2).shape)
        print(c.squeeze(3).shape)  # shape不是1,不会变
    

    输出结果

    torch.Size([1, 32, 1, 2])
    torch.Size([32, 1, 2])
    torch.Size([1, 32, 1, 2])
    torch.Size([1, 32, 2])
    torch.Size([1, 32, 1, 2])
    

    3.维度扩张

    函数1:expand():扩张到多少,

        # shape的扩张
        # expand():对shape为1的进行扩展,对shape不为1的只能保持不变,因为不知道如何变换,会报错
    
        d = torch.randn(1, 32, 1, 1)
        print(d.shape)
        print(d.expand(4, 32, 14, 14).shape)
    

    输出结果

    torch.Size([1, 32, 1, 1])
    torch.Size([4, 32, 14, 14])
    

    函数2:repeat()方法,扩张多少倍

        d=torch.randn([1,32,4,5])
        print(d.shape)
        print(d.repeat(4,32,2,3).shape)
    

    输出结果

    torch.Size([1, 32, 4, 5])
    torch.Size([4, 1024, 8, 15])
    
  • 相关阅读:
    Atom + activate-power-mode震屏插件Windows7下安装
    通过Google身份验证器加强Linux帐户安全
    adb 常用命令总结
    excel 文件加密
    docker 进入容器命令行 /bin/bash 后不支持中文
    无法获取 gcr.io 上的镜像的解决方法
    mysql unix 时间戳转换
    docker 镜像如何导入导出以及建立自己的镜像仓库
    asp.net core 文件的处理
    docker compose 设置环境变量
  • 原文地址:https://www.cnblogs.com/52dxer/p/13771279.html
Copyright © 2011-2022 走看看