  • PyTorch installation and getting started (Part 2): tensor computation operations

    Math operations

    1. Add/subtract/multiply/divide
    2. Matmul
    3. Pow
    4. Sqrt/rsqrt
    5. Round

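    Addition, subtraction, multiplication, division

    The operators +, -, *, / work element-wise and are equivalent to torch.add, torch.sub, torch.mul and torch.div. The two operands must be broadcastable: shapes are compared starting from the trailing dimension, so a [3,4] tensor can be combined with a [4] tensor but not with a [3] tensor.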

    >>> import torch
    >>> a=torch.rand(3,4)
    >>> b=torch.rand(3)
    >>> a+b
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1
    >>>
    >>> b=torch.rand(4)
    >>> a+b
    tensor([[1.0229, 1.0478, 1.5048, 0.1652],
            [0.7158, 1.5854, 0.9578, 1.0546],
            [0.8643, 1.6602, 1.1447, 0.4042]])
    >>> a-b
    tensor([[ 0.8146, -0.9272, -0.3113, -0.1604],
            [ 0.5075, -0.3895, -0.8584,  0.7290],
            [ 0.6560, -0.3147, -0.6715,  0.0786]])
    >>> torch.all(torch.eq(a-b,torch.sub(a,b)))
    tensor(True)
    >>>
    >>> torch.all(torch.eq(a+b,torch.add(a,b)))
    tensor(True)
    >>> torch.all(torch.eq(a*b,torch.mul(a,b)))
    tensor(True)
    >>> torch.all(torch.eq(a/b,torch.div(a,b)))
    tensor(True)
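    A short sketch of the broadcasting rule behind the failed a+b above, assuming a reasonably recent PyTorch. The names row_bias and col_bias are illustrative only. A [3] vector can still be added row by row once unsqueeze reshapes it to [3,1]:

```python
import torch

a = torch.rand(3, 4)
row_bias = torch.rand(4)   # illustrative: one value per column
col_bias = torch.rand(3)   # illustrative: one value per row

# [3, 4] + [4]: trailing dimensions match (4 == 4), so this broadcasts.
print((a + row_bias).shape)   # torch.Size([3, 4])

# [3, 4] + [3] fails (4 != 3). Reshaping to [3, 1] lines the vector up
# with the rows; the size-1 dimension is then stretched across 4 columns.
out = a + col_bias.unsqueeze(1)
print(out.shape)              # torch.Size([3, 4])

# Every column of `out` is the matching column of `a` plus col_bias.
print(torch.allclose(out[:, 0], a[:, 0] + col_bias))  # True
```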

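    matmul: matrix product

    torch.mm works only on 2D tensors (matrices). torch.matmul and the @ operator also accept 1D vectors and batched tensors, multiplying over the last two dimensions. Note that * is element-wise multiplication, not the matrix product.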

    >>> a
    tensor([[0.5370, 0.6411],
            [0.6101, 0.4873]])
    >>> b
    tensor([[0.6591, 0.9531],
            [0.9422, 0.0774]])
    >>> a*b
    tensor([[0.3539, 0.6110],
            [0.5749, 0.0377]])
    >>> a@b
    tensor([[0.9580, 0.5614],
            [0.8612, 0.6192]])
    >>> torch.all(torch.eq(torch.matmul(a,b),a@b))
    tensor(True)
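    A sketch of how torch.mm and torch.matmul differ on shapes; the tensor sizes are arbitrary:

```python
import torch

# torch.mm accepts only 2D matrices: [2, 3] @ [3, 5] -> [2, 5].
x = torch.rand(2, 3)
w = torch.rand(3, 5)
print(torch.mm(x, w).shape)        # torch.Size([2, 5])

# torch.matmul / @ broadcasts over leading "batch" dimensions:
# [4, 2, 3] @ [3, 5] -> [4, 2, 5], one matrix product per batch entry.
batch = torch.rand(4, 2, 3)
print((batch @ w).shape)           # torch.Size([4, 2, 5])

# Two 1D vectors give their dot product as a 0-dim tensor.
u = torch.tensor([1., 2., 3.])
v = torch.tensor([4., 5., 6.])
print(u @ v)                       # tensor(32.)
```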

    pow: powers

    Pass a float fill value (3.) so that torch.full returns a float tensor; recent PyTorch versions infer an integer dtype from an integer fill value.

    >>> torch.full([2,2],3.)
    tensor([[3., 3.],
            [3., 3.]])
    >>> torch.full([2,2],3.)**3
    tensor([[27., 27.],
            [27., 27.]])
    >>> torch.full([2,2],3.).pow(2)
    tensor([[9., 9.],
            [9., 9.]])
    >>> torch.full([2,2],3.).pow(2).sqrt()
    tensor([[3., 3.],
            [3., 3.]])
    >>> torch.full([2,2],3.).pow(2).rsqrt()
    tensor([[0.3333, 0.3333],
            [0.3333, 0.3333]])

    Note: rsqrt is the reciprocal square root, 1/sqrt(x).
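    A quick numerical check of the relations between pow, sqrt and rsqrt on a small example tensor:

```python
import torch

x = torch.tensor([1., 4., 9.])

# rsqrt(x) computes 1 / sqrt(x) in a single call.
print(torch.rsqrt(x))                                     # tensor([1.0000, 0.5000, 0.3333])
print(torch.allclose(torch.rsqrt(x), 1 / torch.sqrt(x)))  # True

# sqrt(x) equals x raised to the power 0.5.
print(torch.allclose(x.sqrt(), x.pow(0.5)))               # True

# pow also accepts a tensor exponent, applied element-wise: 1**1, 4**2, 9**3.
cubes = x.pow(torch.tensor([1., 2., 3.]))
print(cubes)
```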

    exp and log (torch.log is the natural logarithm, base e)

    >>> torch.exp(torch.full([2,2],1))
    tensor([[2.7183, 2.7183],
            [2.7183, 2.7183]])
    
    >>> torch.log((torch.exp(torch.full([2,2],1))))
    tensor([[1., 1.],
            [1., 1.]])
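    A small sketch of the base-2 and base-10 variants, plus the change-of-base rule for any other base (the base 3 is just an example):

```python
import math
import torch

x = torch.tensor([1., 2., 8.])

# exp and log (base e) undo each other.
print(torch.allclose(torch.log(torch.exp(x)), x))   # True

# Fixed-base logarithms.
print(torch.log2(x))                                 # values 0, 1, 3
print(torch.log10(torch.tensor([1., 10., 100.])))    # values 0, 1, 2

# Any base b via change of base: log_b(y) = ln(y) / ln(b).
y = torch.tensor([1., 3., 9.])
log3 = torch.log(y) / math.log(3.0)
print(log3)                                          # approximately 0, 1, 2
```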

    Approximation

    • .floor() / .ceil(): round down / round up
    • .round(): round to the nearest integer
    • .trunc() / .frac(): integer part / fractional part

    >>> a=torch.tensor(3.14)
    >>> a.floor(),a.ceil(),a.trunc(),a.round(),a.frac()
    (tensor(3.), tensor(4.), tensor(3.), tensor(3.), tensor(0.1400))
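    The example above uses a positive number. For negative values floor and trunc differ, and round() sends exact halves to the nearest even integer. A quick check:

```python
import torch

b = torch.tensor(-3.14)
# floor rounds toward -infinity, trunc rounds toward zero.
print(b.floor(), b.trunc())   # tensor(-4.) tensor(-3.)
# frac keeps the sign: frac(x) = x - trunc(x).
print(b.frac())               # tensor(-0.1400)

# round() uses round-half-to-even for exact .5 values.
halves = torch.tensor([0.5, 1.5, 2.5, -2.5])
print(halves.round())         # values 0, 2, 2, -2
```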

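    clamp

    clamp(min, max) limits every element to the range [min, max]; a.clamp(10) sets only the lower bound (the same as a.clamp(min=10)). A typical use is gradient clipping.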

    >>> a=torch.rand(2,2)*14
    >>> a
    tensor([[ 1.3472,  5.9060],
            [12.0558,  4.2571]])
    >>> a.clamp(10)
    tensor([[10.0000, 10.0000],
            [12.0558, 10.0000]])
    >>> a.clamp(0,10)
    tensor([[ 1.3472,  5.9060],
            [10.0000,  4.2571]])
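    Gradient clipping with clamp, as a minimal sketch: the toy nn.Linear model, the seed and the clip value 0.5 are arbitrary choices made only to produce some gradients. Clamping each .grad in place should give the same result as PyTorch's torch.nn.utils.clip_grad_value_ helper:

```python
import torch

def make_model_with_grads():
    # Hypothetical toy setup: a fixed seed so two calls build identical
    # models with identical gradients.
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 2)
    loss = (model(torch.randn(8, 4)) * 100).pow(2).sum()  # large loss, large gradients
    loss.backward()
    return model

clip = 0.5
model_a = make_model_with_grads()
model_b = make_model_with_grads()

# Option 1: clamp each gradient tensor in place into [-clip, clip].
for p in model_a.parameters():
    p.grad.clamp_(-clip, clip)

# Option 2: the built-in helper performs the same element-wise clipping.
torch.nn.utils.clip_grad_value_(model_b.parameters(), clip)

for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    print(torch.equal(pa.grad, pb.grad), pa.grad.abs().max().item() <= clip)
```

    clip_grad_value_ clips each element independently. torch.nn.utils.clip_grad_norm_ instead rescales all gradients together so their total norm stays below a threshold, which preserves the gradient direction.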

    Statistics

    ▪ norm
    ▪ mean sum
    ▪ prod
    ▪ max, min, argmin, argmax
    ▪ kthvalue, topk
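    The statistics listed above, applied to a small fixed tensor so the results can be checked by hand (a sketch; the values are chosen for illustration):

```python
import torch

a = torch.tensor([[1., 5., 3.],
                  [4., 2., 6.]])

print(a.norm())                 # L2 norm of all elements: sqrt(91)
print(a.norm(p=1, dim=1))       # L1 norm per row: 9, 12
print(a.sum(), a.mean(), a.prod())   # 21, 3.5, 720

# max/min with a dim return (values, indices).
values, idx = a.max(dim=1)
print(values, idx)              # values 5, 6 at indices 1, 2

# argmax without dim indexes into the flattened tensor.
print(a.argmax())               # 5 (the 6. at flat position 5)

# topk: the k largest values per row, in descending order.
print(a.topk(2, dim=1).values)  # rows [5, 3] and [6, 4]
# kthvalue: the k-th *smallest* value per row.
print(a.kthvalue(2, dim=1).values)  # 3, 4
```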

  • Original article: https://www.cnblogs.com/wqbin/p/12693355.html