zoukankan      html  css  js  c++  java
  • pytorch 常见函数理解

    gather

    >>> a = torch.Tensor([[1,2],[3,4]])
    >>> a
    tensor([[ 1.,  2.],
            [ 3.,  4.]])
    >>> torch.gather(a,1,torch.LongTensor([
    ... [0,0],
    ... [1,0]]))
    tensor([[ 1.,  1.],
            [ 4.,  3.]])
    #1代表按照第1维度进行计算
    #第一维也就是按照行,第一行[0,0]代表,新的tensor的第一行的两个元素,分别是a第一行的的第0个和第0个元素
    #第一维也就是按照行,第二行[1,0]代表,新的tensor的第二行的两个元素,分别是a第二行的第1个和第0个元素
    >>> torch.gather(a,0,torch.LongTensor([
    ... [0,0],
    ... [1,0]]))
    tensor([[ 1.,  2.],
            [ 3.,  2.]])
    #0代表按照第0维度进行计算
    #第0维也就是按照列,第二列[0,0]代表,新的tensor的第二列的两个元素,分别是a第二列的第0个和第0个元素

    squeeze 

    将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)

    import torch
    a=torch.randn(2,1,1,3)
    print(a)
    print(a.squeeze())

    输出:

    tensor([[[[-0.2320, 0.9513, 1.1613]]],
    
    
    [[[ 0.0901, 0.9613, -0.9344]]]])
    tensor([[-0.2320, 0.9513, 1.1613], [ 0.0901, 0.9613, -0.9344]])

      

    expand

    扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)

    import torch
    x=torch.randn(2,2,1)
    print(x)
    y=x.expand(2,2,3)
    print(y)


    输出:

    tensor([[[ 0.0608],
    [ 2.2106]],
    
    [[-1.9287],
    [ 0.8748]]])
    tensor([[[ 0.0608, 0.0608, 0.0608],
    [ 2.2106, 2.2106, 2.2106]],
    
    [[-1.9287, -1.9287, -1.9287],
    [ 0.8748, 0.8748, 0.8748]]])
    

      

    参考:https://blog.csdn.net/hbu_pig/article/details/81454503

    sum

    size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量

    import torch
    a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
    print(a.sum())
    print(a.sum(dim=1))

    输出:

    tensor(60)
    tensor([[ 5, 10, 15],
    [ 5, 10, 15]])
    

      

    contiguous

    返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。可以通过is_contiguous查看张量内存是否连续。

    import torch
    a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
    print(a.is_contiguous)
    
    print(a.contiguous().view(4,3))

    输出:

    <built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
    tensor([[ 1, 2, 3],
    [ 4, 8, 12],
    [ 1, 2, 3],
    [ 4, 8, 12]])
    

      

    softmax

    假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。

    import torch
    import torch.nn.functional as F
    
    a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
    b=F.softmax(a,dim=1)
    print(b)

    输出:

    tensor([[ 0.5000, 0.5000],
    [ 0.7311, 0.2689],
    [ 0.8808, 0.1192],
    [ 0.2689, 0.7311],
    [ 0.1192, 0.8808]])
    

    max

    返回最大值,或指定维度的最大值以及index

    import torch
    a=torch.tensor([[.1,.2,.3],
    [1.1,1.2,1.3],
    [2.1,2.2,2.3],
    [3.1,3.2,3.3]])
    print(a.max(dim=1))
    print(a.max())

    输出:

    (tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
    tensor(3.3000)
    

      

    argmax

    返回最大值的index

    import torch
    a=torch.tensor([[.1,.2,.3],
    [1.1,1.2,1.3],
    [2.1,2.2,2.3],
    [3.1,3.2,3.3]])
    print(a.argmax(dim=1))
    print(a.argmax(dim=0))
    print(a.argmax())

    输出:

    tensor([ 2, 2, 2, 2])
    tensor([ 3, 3, 3])
    tensor(11)
    

      

      



  • 相关阅读:
    CF763C Timofey and Remoduling
    CF762E Radio Stations
    CF762D Maximum Path
    CF763B Timofey and Rectangles
    URAL1696 Salary for Robots
    uva10884 Persephone
    LA4273 Post Offices
    SCU3037 Painting the Balls
    poj3375 Network Connection
    Golang zip压缩文件读写操作
  • 原文地址:https://www.cnblogs.com/hozhangel/p/10030246.html
Copyright © 2011-2022 走看看