zoukankan      html  css  js  c++  java
  • pytorch中的unsqueeze函数和squeeze函数

    在pytorch中,我们经常对张量Tensor的维度进行压缩或者扩充(压缩或者扩充的维度为1)。其中经常使用的是squeeze()函数和unsqueeze函数;
    squeeze在英文中的意思就是“挤、压”,所以故名思议,squeeze()函数就是对张量的维度进行减少的操作,话不多说,我们直接看下例子:

    import torch
    #定义两个整型的张量a,b
    a = torch.IntTensor([[1,2,3],[4,5,6]])
    b = torch.IntTensor([[[1,2,3],[4,5,6]]])
    #看一下a,b的形状
    print(a.shape)
    print(b.shape)
    '''
    ===output===
    torch.Size([2, 3])
    torch.Size([1, 2, 3])
    '''
    
    #我们看到张量b比较膨胀,有三个维度:1*2*3,所以我们要挤压一下张量b的第0个维度(因为是1才能挤压,否则没有效果)
    c = torch.squeeze(b,0)  # 对应的维度为第0维
    print(c.shape)
    '''
    ===output===
    torch.Size([2, 3])
    '''
    #那如果想想张量a膨胀一下,怎么办
    c = torch.unsqueeze(a,0)
    print(c.shape)
    '''
    ===output===
    torch.Size([1, 2, 3])
    '''
    #可以看到张量a在第0维也膨胀了, 如果你看不惯的话,再压缩一下它。
    

    另外,squeeze()函数和unsqueeze()函数还有另一种写法,直接用张量类型的变量来调用这两个函数:

    c = a.unsqueeze(0)
    print(c.shape)
    '''
    ===output===
    torch.Size([1, 2, 3])
    '''
    

    你看出差别了么?这里直接用张量变量a来调用了unsqueeze()函数,当然squeeze()也是一样的,不信你可以试试^_^

  • 相关阅读:
    oracle-报错 RMAN-03002,RMAN-06172
    oracle--报错 ORA-01003,ORA-09817,ORA-01075
    oracle--报错 ORA-00257
    Linux-iostat命令
    oracle--查询速度慢
    linux-根目录添加内存
    mysq-5.7忘记密码修改
    zsh: command not found cnpm,gulp等命令在zsh终端上报错的问题
    vue中的js引入图片,必须require进来
    如何启动一个Vue3.x项目
  • 原文地址:https://www.cnblogs.com/datasnail/p/13086803.html
Copyright © 2011-2022 走看看