zoukankan      html  css  js  c++  java
  • PyTorch 两大转置函数 transpose() 和 permute()

    PyTorch 两大转置函数 transpose() 和 permute(), 以及RuntimeError: invalid argument 2: view size is not compati


    关心差别的可以直接看[3.不同点]和[4.连续性问题]
    前言
    在pytorch中转置用的函数就只有这两个

    transpose()
    permute()
    注意只有transpose()有后缀格式:transpose_():后缀函数的作用是简化如下代码:

    x = x.transpose(0,1)
    等价于
    x.transpose_()
    # 相当于x = x + 1 简化为 x+=1

    这两个函数都是交换维度的操作。有一些细微的区别

    1. 官方文档
    transpose()
    torch.transpose(input, dim0, dim1, out=None) → Tensor

    函数返回输入矩阵input的转置。交换维度dim0和dim1

    参数:

    input (Tensor) – 输入张量,必填
    dim0 (int) – 转置的第一维,默认0,可选
    dim1 (int) – 转置的第二维,默认1,可选
    permute()
    permute(dims) → Tensor

    将tensor的维度换位。

    参数:

    dims (int…*)-换位顺序,必填
    2. 相同点
    都是返回转置后矩阵。
    都可以操作高纬矩阵,permute在高维的功能性更强。
    3.不同点
    先定义我们后面用的数据如下

    # 创造二维数据x,dim=0时候2,dim=1时候3
    x = torch.randn(2,3) 'x.shape → [2,3]'
    # 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4
    y = torch.randn(2,3,4) 'y.shape → [2,3,4]'

    合法性不同
    torch.transpose(x)合法, x.transpose()合法。
    tensor.permute(x)不合法,x.permute()合法。

    参考第二点的举例

    操作dim不同:
    transpose()只能一次操作两个维度;permute()可以一次操作多维数据,且必须传入所有维度数,因为permute()的参数是int*。

    举例

    # 对于transpose
    x.transpose(0,1) 'shape→[3,2] '
    x.transpose(1,0) 'shape→[3,2] '
    y.transpose(0,1) 'shape→[3,2,4]'
    y.transpose(0,2,1) 'error,操作不了多维'

    # 对于permute()
    x.permute(0,1) 'shape→[2,3]'
    x.permute(1,0) 'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) '
    y.permute(0,1) "error 没有传入所有维度数"
    y.permute(1,0,2) 'shape→[3,2,4]'

    transpose()中的dim没有数的大小区分;permute()中的dim有数的大小区分
    举例,注意后面的shape:

    # 对于transpose,不区分dim大小
    x1 = x.transpose(0,1) 'shape→[3,2] '
    x2 = x.transpose(1,0) '也变换了,shape→[3,2] '
    print(torch.equal(x1,x2))
    ' True ,value和shape都一样'

    # 对于permute()
    x1 = x.permute(0,1) '不同transpose,shape→[2,3] '
    x2 = x.permute(1,0) 'shape→[3,2] '
    print(torch.equal(x1,x2))
    'False,和transpose不同'

    y1 = y.permute(0,1,2) '保持不变,shape→[2,3,4] '
    y2 = y.permute(1,0,2) 'shape→[3,2,4] '
    y3 = y.permute(1,2,0) 'shape→[3,4,2] '

    4.关于连续contiguous()
    经常有人用view()函数改变通过转置后的数据结构,导致报错
    RuntimeError: invalid argument 2: view size is not compatible with input tensor's....

    这是因为tensor经过转置后数据的内存地址不连续导致的,也就是tensor . is_contiguous()==False
    这时候reshape()可以改变该tensor结构,但是view()不可以,具体不同可以看view和reshape的区别
    例子如下:

    x = torch.rand(3,4)
    x = x.transpose(0,1)
    print(x.is_contiguous()) # 是否连续
    'False'
    # 再view会发现报错
    x.view(3,4)
    '''报错
    RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
    '''

    # 但是下面这样是不会报错。
    x = x.contiguous()
    x.view(3,4)

    我们再看看reshape()

    x = torch.rand(3,4)
    x = x.permute(1,0) # 等价x = x.transpose(0,1)
    x.reshape(3,4)
    '''这就不报错了
    说明x.reshape(3,4) 这个操作
    等于x = x.contiguous().view()
    尽管如此,但是torch文档中还是不推荐使用reshape
    理由是除非为了获取完全不同但是数据相同的克隆体
    '''

    调用contiguous()时,会强制拷贝一份tensor,让它的布局和从头创建的一毛一样。
    (这一段看文字你肯定不理解,你也可以不用理解,有空我会画图补上)

    只需要记住了,每次在使用view()之前,该tensor只要使用了transpose()和permute()这两个函数一定要contiguous().

    5.总结
    最重要的区别应该是上面的第三点和第四个。

    另外,简单的数据用transpose()就可以了,但是个人觉得不够直观,指向性弱了点;复杂维度的可以用permute(),对于维度的改变,一般更加精准。
    ————————————————
    版权声明:本文为CSDN博主「模糊包」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    原文链接:https://blog.csdn.net/xinjieyuan/article/details/105232802




    如果这篇文章帮助到了你,你可以请作者喝一杯咖啡

  • 相关阅读:
    多线程系列二:原子操作
    多线程系列一:线程基础
    java面试题——高级篇
    SpringMVC系列(十六)Spring MVC与Struts2的对比
    SpringMVC系列(十五)Spring MVC与Spring整合时实例被创建两次的解决方案以及Spring 的 IOC 容器和 SpringMVC 的 IOC 容器的关系
    Android IOS WebRTC 音视频开发总结(三六)-- easyRTC介绍
    Android IOS WebRTC 音视频开发总结(三五)-- chatroulette介绍
    Android IOS WebRTC 音视频开发总结(三四)-- windows.20150706
    Android IOS WebRTC 音视频开发总结(三三)-- Periscope介绍
    Android IOS WebRTC 音视频开发总结(三二)-- WebRTC项目开发建议
  • 原文地址:https://www.cnblogs.com/sddai/p/14548100.html
Copyright © 2011-2022 走看看