zoukankan      html  css  js  c++  java
  • pytorch安装与入门(二)--tensor的计算操作

    Math operation 数学运算

    1. Add/minus/multiply/divide
    2. Matmul
    3. Pow
    4. Sqrt/rsqrt
    5. Round

    加减乘除

    >>> a=torch.rand(3,4)
    >>> b=torch.rand(3)
    >>> a+b
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1
    >>>
    >>> b=torch.rand(4)
    >>> a+b
    tensor([[1.0229, 1.0478, 1.5048, 0.1652],
            [0.7158, 1.5854, 0.9578, 1.0546],
            [0.8643, 1.6602, 1.1447, 0.4042]])
    >>> a-b
    tensor([[ 0.8146, -0.9272, -0.3113, -0.1604],
            [ 0.5075, -0.3895, -0.8584,  0.7290],
            [ 0.6560, -0.3147, -0.6715,  0.0786]])
    >>> torch.all(torch.eq(a-b,torch.sub(a,b)))
    tensor(True)
    >>>
    >>> torch.all(torch.eq(a+b,torch.add(a,b)))
    tensor(True)
    >>> torch.all(torch.eq(a*b,torch.mul(a,b)))
    tensor(True)
    >>> torch.all(torch.eq(a/b,torch.div(a,b)))
    tensor(True)

    matmul矩阵乘积

     只能对 2d张量(矩阵)进行使用。

    >>> a
    tensor([[0.5370, 0.6411],
            [0.6101, 0.4873]])
    >>> b
    tensor([[0.6591, 0.9531],
            [0.9422, 0.0774]])
    >>> a*b
    tensor([[0.3539, 0.6110],
            [0.5749, 0.0377]])
    >>> a@b
    tensor([[0.9580, 0.5614],
            [0.8612, 0.6192]])
    >>> torch.all(torch.eq(torch.matmul(a,b),a@b))
    tensor(True)

    pow指数

    >>> torch.full([2,2],3)
    tensor([[3., 3.],
            [3., 3.]])
    >>> torch.full([2,2],3)**3
    tensor([[27., 27.],
            [27., 27.]])
    >>> torch.full([2,2],3).pow(2)
    tensor([[9., 9.],
            [9., 9.]])
    >>> torch.full([2,2],3).pow(2).sqrt()
    tensor([[3., 3.],
            [3., 3.]])
    >>> torch.full([2,2],3).pow(2).rsqrt()
    tensor([[0.3333, 0.3333],
            [0.3333, 0.3333]])

    note:rsqrt倒数平方根

    exp log

    >>> torch.exp(torch.full([2,2],1))
    tensor([[2.7183, 2.7183],
            [2.7183, 2.7183]])
    
    >>> torch.log((torch.exp(torch.full([2,2],1))))
    tensor([[1., 1.],
            [1., 1.]])

    Approximation

    • floor() .ceil()
    • round()
    • trunc() .frac()
    >>> a=torch.tensor(3.14)
    >>> a.floor(),a.ceil(),a.trunc(),a.round(),a.frac()
    (tensor(3.), tensor(4.), tensor(3.), tensor(3.), tensor(0.1400))

    clamp
    gradient clipping

    >>> a=torch.rand(2,2)*14
    >>> a
    tensor([[ 1.3472,  5.9060],
            [12.0558,  4.2571]])
    >>> a.clamp(10)
    tensor([[10.0000, 10.0000],
            [12.0558, 10.0000]])
    >>> a.clamp(0,10)
    tensor([[ 1.3472,  5.9060],
            [10.0000,  4.2571]])

    statistics 统计属性

    ▪ norm
    ▪ mean sum
    ▪ prod
    ▪ max, min, argmin, argmax
    ▪ kthvalue, topk

  • 相关阅读:
    牛客网·剑指offer 从尾到头打印链表(JAVA)
    牛客网·剑指offer 替换空格(JAVA)
    简单的用户登录后台程序编写
    牛客网&剑指offer 二维数组中的查找(JAVA)
    洛谷 P1603 斯诺登的密码(JAVA)
    【回溯法】八皇后问题(递归和非递归)
    如何使用SecureCRT让Vim有颜色?
    js 转base64字符串为文件
    springboot 测试类
    oracle 登录、重启服务
  • 原文地址:https://www.cnblogs.com/wqbin/p/12693355.html
Copyright © 2011-2022 走看看