zoukankan      html  css  js  c++  java
  • pytorch函数

    1、squeeze()函数和unsqueeze()函数
    首先要知道tensor的维度,比如tensor([[0, 1, 2],[ 3, 4, 5]]),维度是(2, 3),当其维度变为(2, 1, 3)时,代表2个1行3列的矩阵,为tensor([[[0, 1, 2],[[3, 4, 5]]])。
    squeeze()函数就是减少一个维度,unsqueeze()函数就是增加一个维度。比如上述将(2, 3)变为(2, 1, 3)就是unsqueeze()操作。而squeeze()只有维数为1时才能去掉,去掉只需在括号中写入要去掉的维度数即可,如(2, 1, 3)去掉第1维变成(2, 3)。
    2、transpose()函数
    作用是交换维度,比如:

    x = torch.randn(2, 3)
    >>> x
    tensor([[ 1.0028, -0.9893,  0.5809],
            [-0.1669,  0.7299,  0.4942]])
    >>> torch.transpose(x, 0, 1)
    tensor([[ 1.0028, -0.1669],
            [-0.9893,  0.7299],
            [ 0.5809,  0.4942]])
    

    3、expand()和expand_as()函数
    expand()函数

    >>> x = torch.Tensor([[1], [2], [3]])
    >>> y = x.expand(3, 3)
    >>> print(x)
    tensor([[1.],
            [2.],
            [3.]])
    >>> print(y)
    tensor([[1., 1., 1.],
            [2., 2., 2.],
            [3., 3., 3.]])
    >>> print(x.shape)
    torch.Size([3, 1])
    >>> print(y.shape)
    torch.Size([3, 3])
    

    expand_as()函数:expand_as(tensor)将张量扩展为参数tensor的大小。

    >>> x = torch.randn(1, 3, 1, 1)
    >>> y = torch.randn(1, 3, 3, 3)
    >>> z = x.expand_as(y)
    >>> print(x)
    tensor([[[[ 0.4383]],
    
             [[-1.5909]],
    
             [[ 0.0814]]]])
    >>> print(z)
    tensor([[[[ 0.4383,  0.4383,  0.4383],
              [ 0.4383,  0.4383,  0.4383],
              [ 0.4383,  0.4383,  0.4383]],
    
             [[-1.5909, -1.5909, -1.5909],
              [-1.5909, -1.5909, -1.5909],
              [-1.5909, -1.5909, -1.5909]],
    
             [[ 0.0814,  0.0814,  0.0814],
              [ 0.0814,  0.0814,  0.0814],
              [ 0.0814,  0.0814,  0.0814]]]])
    

    4、permute()函数
    将tensor的维度换位
    比如图片img的size是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。

    比如Tensor([[[1,2,3],[4,5,6]]]),使用permute(0,2,1)可以将其转换成tensor([[[1., 4.], [2., 5.], [3., 6.]]])。
    

    5、argmax()函数和argmin()函数
    获取张量在某个维度的最大值和最小值的位置。
    argmax函数:torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号。dim代表该维度会消失,例如

    import torch
    t = torch.tensor([[1,2],[3,4],[2,8]])
    print(torch.argmax(t,0))
    
    g = torch.tensor([[[1,2,3],[2,3,4],[5,6,7]], [[3,4,5],[7,6,5],[5,4,3]], [[8,9,0],
                                [2,8,4],[7,5,3]]])
    print(g)
    print(torch.argmax(g,0))
    

    对于二维张量t来说,大小为(3, 2),使dim为0,意思是求第0维的最大值的序号,则固定行,直接看列,比较结果为tensor([1, 2])。
    对于三维张量g来说,大小为(3, 3, 3),使dim为0,则固定第一个维度,其余维度对应位置进行比较,得到结果为tensor([[2, 2, 1], [1, 2, 1], [2, 0, 0]])。
    6、numel()函数
    返回数组中元素的个数。例如:

    params = sum(p.numel() for p in list(net.parameters())) / 1e6 # numel()
    print('#Params: %.1fM' % (params))
    

    net.parameters():是Pytorch用法,用来返回net网络中的参数,而params则用来返回net网络中的参数的总数目。
    7、fit, transform, fit_transform

  • 相关阅读:
    下载安装安卓开发工具
    斐波那契数列
    各位相加
    求连续最大子序列和
    反转字符串中的单词
    统计位数为偶数的数字
    Express框架的整体感知
    node.js当中的http模块与url模块的简单介绍
    formidable处理提交的表单或文件的简单介绍
    30分钟,学会经典小游戏编程!
  • 原文地址:https://www.cnblogs.com/zyr001/p/14539689.html
Copyright © 2011-2022 走看看