zoukankan      html  css  js  c++  java
  • Pytorch 常用函数

    1. torch.renorm(inputpdimmaxnormout=None) → Tensor

    Returns a tensor where each sub-tensor of input along dimension dim is normalized such that the p-norm of the sub-tensor is lower than the value maxnorm。

    解释:返回一个张量,包含规范化后的各个子张量,使得沿着dim维划分的各子张量的p范数小于maxnorm

    >>> x = torch.Tensor([[1,2,3]])
    >>> torch.renorm(x,2,0,1) tensor([[ 0.2673, 0.5345, 0.8018]])

    2. torch. scatter_(dimindexsrc) → Tensor

    src中的所有值按照index确定的索引写入本tensor中。其中索引是根据给定的dimension,dim按照gather()描述的规则来确定。

    注意,index的值必须是在0(self.size(dim)-1)之间,

    参数:

    • input (Tensor)-源tensor
    • dim (int)-索引的轴向
    • index (LongTensor)-散射元素的索引指数
    • src (Tensor or float)-散射的源元素
     1 >>> x = torch.rand(2, 5)
     2 >>> x
     3  0.4319  0.6500  0.4080  0.8760  0.2355
     4  0.2609  0.4711  0.8486  0.8573  0.1029
     5 [torch.FloatTensor of size 2x5]
    6 >>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) #将 x 按照格式写入新的Tensor里 7 0.4319 0.4711 0.8486 0.8760 0.2355 8 0.0000 0.6500 0.0000 0.8573 0.0000 9 0.2609 0.0000 0.4080 0.0000 0.1029 10 [torch.FloatTensor of size 3x5]
    11 >>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23) 12 >>> z 13 0.0000 0.0000 1.2300 0.0000 14 0.0000 0.0000 0.0000 1.2300 15 [torch.FloatTensor of size 2x4]

    3.  torch.gather(input, dim, index, out=None) Tensor

    沿给定轴dim,将输入索引张量index指定位置的值进行聚合。

    参数:

    • input (Tensor) – 源张量
    • dim (int) – 索引的轴
    • index (LongTensor) – 聚合元素的下标
    • out (Tensor, optional) – 目标张量
    >>> t = torch.Tensor([[1,2],[3,4]])
    >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
     1  1
     4  3
    [torch.FloatTensor of size 2x2]

     or:

    >>> s=torch.randn(3,6)
    >>> s
    tensor([[-0.4857, -0.0982, -0.6532, -1.0273, -0.9205, -0.7440],
            [-0.6890, -0.3474, -1.4337, -0.3511, -0.2443, -0.6398],
            [ 1.2902,  1.1210,  1.7374,  0.0902, -0.4524, -0.6898]])
    >>> s.gather(1,torch.LongTensor([[1,2,1],[1,2,3],[1,2,3]]))
    tensor([[-0.0982, -0.6532, -0.0982],
            [-0.3474, -1.4337, -0.3511],
            [ 1.1210,  1.7374,  0.0902]])

    4. pytorch改变维度的操作

    Pytorch Tensor维度变换

  • 相关阅读:
    js-AOP
    jQueryUI之autocomplete
    nginx安装配置
    oracle结构语法
    ajax/表单提交 多个相同name的处理方法
    ES6模块化
    docker运维
    帆软报表
    oracle锁表
    香港到大陆IPLC节点故障
  • 原文地址:https://www.cnblogs.com/king-lps/p/10726902.html
Copyright © 2011-2022 走看看