zoukankan      html  css  js  c++  java
  • PyTorch 两大转置函数 transpose() 和 permute(),

    pytorch中转置用的函数就只有这两个

    1. transpose()
    2. permute()
    transpose()
    torch.transpose(input, dim0, dim1, out=None) → Tensor

    函数返回输入矩阵input的转置。交换维度dim0dim1

    参数:

    • input (Tensor) – 输入张量,必填
    • dim0 (int) – 转置的第一维,默认0,可选
    • dim1 (int) – 转置的第二维,默认1,可选

    注意只能有两个相关的交换的位置参数。

    permute()
    参数:
    
    dims (int…*)-换位顺序,必填

    相同点

    1. 都是返回转置后矩阵。
    2. 都可以操作高纬矩阵,permute在高维的功能性更强。
    # 创造二维数据x,dim=0时候2,dim=1时候3
    x = torch.randn(2,3)       'x.shape  →  [2,3]'
    # 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4
    y = torch.randn(2,3,4)   'y.shape  →  [2,3,4]'
    # 对于transpose
    x.transpose(0,1)     'shape→[3,2] '  
    x.transpose(1,0)     'shape→[3,2] '  
    y.transpose(0,1)     'shape→[3,2,4]' 
    y.transpose(0,2,1)  'error,操作不了多维'
    
    # 对于permute()
    x.permute(0,1)     'shape→[2,3]'
    x.permute(1,0)     'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) '
    y.permute(0,1)     "error 没有传入所有维度数"
    y.permute(1,0,2)  'shape→[3,2,4]'
    
    
    合法性不同
    torch.transpose(x)合法, x.transpose()合法。
    tensor.permute(x)不合法,x.permute()合法。
    
    参考第二点的举例
    
    操作dim不同:
    transpose()只能一次操作两个维度;permute()可以一次操作多维数据,且必须传入所有维度数,因为permute()的参数是int*。
    1. transpose()中的dim没有数的大小区分;permute()中的dim有数的大小区分

    举例,注意后面的shape

    # 对于transpose,不区分dim大小
    x1 = x.transpose(0,1)   'shape→[3,2] '  
    x2 = x.transpose(1,0)   '也变换了,shape→[3,2] '  
    print(torch.equal(x1,x2))
    ' True ,value和shape都一样'
    
    # 对于permute()
    x1 = x.permute(0,1)     '不同transpose,shape→[2,3] '  
    x2 = x.permute(1,0)     'shape→[3,2] '  
    print(torch.equal(x1,x2))
    'False,和transpose不同'
    
    y1 = y.permute(0,1,2)     '保持不变,shape→[2,3,4] '  
    y2 = y.permute(1,0,2)     'shape→[3,2,4] '  
    y3 = y.permute(1,2,0)     'shape→[3,4,2] '  

    view()函数改变通过转置后的数据结构,导致报错
    RuntimeError: invalid argument 2: view size is not compatible with input tensor's....

    这是因为tensor经过转置后数据的内存地址不连续导致的,也就是tensor . is_contiguous()==False
    虽然在torch里面,view函数相当于numpy的reshape,但是这时候reshape()可以改变该tensor结构,但是view()不可以

    x = torch.rand(3,4)
    x = x.transpose(0,1)
    print(x.is_contiguous()) # 是否连续
    'False'
    # 会发现
    x.view(3,4)
    '''
    RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
    就是不连续导致的
    '''
    # 但是这样是可以的。
    x = x.contiguous()
    x.view(3,4)
    x = torch.rand(3,4)
    x = x.permute(1,0) # 等价x = x.transpose(0,1)
    x.reshape(3,4)
    '''这就不报错了
    说明x.reshape(3,4) 这个操作
    等于x = x.contiguous().view()
    尽管如此,但是我们还是不推荐使用reshape
    除非为了获取完全不同但是数据相同的克隆体
    '''

    调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。

    只需要记住了,每次在使用view()之前,该tensor只要使用了transpose()permute()这两个函数一定要contiguous().

  • 相关阅读:
    WPF程序国际化
    MVVM框架搭建
    最全前端开发面试问题及答案整理
    最小化运行批处理
    C#中App.config文件配置获取
    VS2017 打包成exe
    Inno Setup生成桌面快捷方式
    C#文件读写(txt 简单方式)
    Flume 学习笔记之 Flume NG概述及单节点安装
    快学Scala 第二十课 (trait的构造顺序)
  • 原文地址:https://www.cnblogs.com/tingtin/p/13547653.html
Copyright © 2011-2022 走看看