zoukankan      html  css  js  c++  java
  • pytorch常用函数总结(持续更新)

    pytorch常用函数总结(持续更新)

    torch.max(input,dim)

    求取指定维度上的最大值,,返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。比如:

    demo.shape
    Out[7]: torch.Size([10, 3, 10, 10])
    
    torch.max(demo,1)[0].shape
    Out[8]: torch.Size([10, 10, 10])
    

    torch.max(demo,1)[0]这其中的[0]取得就是返回的最大值,torch.max(demo,1)[1]就是返回的最大值对应的位置索引。例子如下:

    a
    Out[8]: 
    tensor([[1., 2., 3.],
            [4., 5., 6.]])
    a.max(1)
    Out[9]: 
    torch.return_types.max(
    values=tensor([3., 6.]),
    indices=tensor([2, 2]))
    

    class torch.nn.ParameterList(parameters=None)

    submodules保存在一个list中。

    ParameterList可以像一般的Python list一样被索引。而且ParameterList中包含的parameters已经被正确的注册,对所有的module method可见。

    参数说明:

    • modules (list, optional) – a list of nn.Parameter

    例子:

    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
    
        def forward(self, x):
            # ModuleList can act as an iterable, or be indexed using ints
            for i, p in enumerate(self.params):
                x = self.params[i // 2].mm(x) + p.mm(x)
            return x
    

    torch.cat()函数

    cat是concatnate的意思:拼接,联系在一起。

    先说cat( )的普通用法

    如果我们有两个tensor是A和B,想把他们拼接在一起,需要如下操作:

    C = torch.cat( (A,B),0 )  #按维数0拼接(竖着拼)
    
    C = torch.cat( (A,B),1 )  #按维数1拼接(横着拼)
    

    相当于将tensor按照指定维度进行拼接,比如A的shape为128*64*32*32,B的shape为 128*32*64*64,那么按照 torch.cat( (A,B),1)拼接的之后的形状为 128*96*64*64

    注意:

    两个tensor要想进行拼接,必须保证除了指定拼接的维度以外其他的维度形状必须相同,比如上面的例子,拼接A和B时,A的形状为128*64*32*32,B的形状为128*32*64*64,只有第二个维度的维数数值不同,其他的维度的维数都是相同的,所以拼接时可按维度1进行拼接(注意,维度的下标是从0开始的,比如 A 的形状对应的维度下标为:1280641322323128_0*64_1*32_2*32_3

    contiguous()函数的使用

    contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形(如:tensor_var.contiguous().view() ),示例如下:

    x = torch.Tensor(2,3)
    y = x.permute(1,0)         # permute:二维tensor的维度变换,此处功能相当于转置transpose
    y.view(-1)                 # 报错,view使用前需调用contiguous()函数
    y = x.permute(1,0).contiguous()
    y.view(-1)                 # OK
    

    具体原因有两种说法:

    1 transpose、permute等维度变换操作后,tensor在内存中不再是连续存储的,而view操作要求tensor的内存连续存储,所以需要contiguous来返回一个contiguous copy;

    2 维度变换后的变量是之前变量的浅拷贝,指向同一区域,即view操作会连带原来的变量一同变形,这是不合法的,所以也会报错;---- 这个解释有部分道理,也即contiguous返回了tensor的深拷贝contiguous copy数据;

    原文链接:https://zhuanlan.zhihu.com/p/64376950

    tensor.repeat()函数

    该函数传入的参数个数不少于tensor的维数,其中每个参数代表的是对该维度重复多少次,也就相当于复制的倍数,结合例子更好理解,如下:

    >>> import torch
    >>> 
    >>> a = torch.randn(33, 55)
    >>> a.size()
    torch.Size([33, 55])
    >>> 
    >>> a.repeat(1, 1).size()
    torch.Size([33, 55])
    >>> 
    >>> a.repeat(2,1).size()
    torch.Size([66, 55])
    >>> 
    >>> a.repeat(1,2).size()
    torch.Size([33, 110])
    >>>
    >>> a.repeat(1,1,1).size()
    torch.Size([1, 33, 55])
    >>>
    >>> a.repeat(2,1,1).size()
    torch.Size([2, 33, 55])
    >>>
    >>> a.repeat(1,2,1).size()
    torch.Size([1, 66, 55])
    >>>
    >>> a.repeat(1,1,2).size()
    torch.Size([1, 33, 110])
    >>>
    >>> a.repeat(1,1,1,1).size()
    torch.Size([1, 1, 33, 55])
    >>> 
    >>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数,
    >>> # 下面是一些错误示例
    >>> a.repeat(2).size()  # 1D < 2D, error
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
    >>>
    >>> b = torch.randn(5,6,7)
    >>> b.size() # 3D
    torch.Size([5, 6, 7])
    >>> 
    >>> b.repeat(2).size() # 1D < 3D, error
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
    >>>
    >>> b.repeat(2,1).size() # 2D < 3D, error
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
    >>>
    >>> b.repeat(2,1,1).size() # 3D = 3D, okay
    torch.Size([10, 6, 7])
    >>>
    
    

    参考博客:https://blog.csdn.net/qq_29695701/article/details/89763168

    保持对优秀的热情
  • 相关阅读:
    Android控件Editext、TextView属性详解
    修改Android签名证书keystore的密码、别名alias以及别名密码
    android 中如何限制 EditText 最大输入字符数
    keytool 错误 java.io.IOException: incorrect AVA format
    Android打包常见错误之Export aborted because fatal lint errors were found
    正则表达式之判断用户注册信息是否为汉字、字母和数字
    Android Dialog 系统样式讲解及透明背景
    Android中自定义Activity和Dialog的位置大小背景和透明度等
    字体在Android View中的输出 drawText
    怎么用CIFilter给图片加上各种各样的滤镜_1
  • 原文地址:https://www.cnblogs.com/luckforefforts/p/13642681.html
Copyright © 2011-2022 走看看