zoukankan      html  css  js  c++  java
  • pytorch基本运算:加减乘除、对数幂次等

    1、加减乘除

    • a + b = torch.add(a, b)
    • a - b = torch.sub(a, b)
    • a * b = torch.mul(a, b)
    • a / b = torch.div(a, b)
    import torch
    
    a = torch.rand(3, 4)
    b = torch.rand(4)
    a
    # 输出:
        tensor([[0.6232, 0.5066, 0.8479, 0.6049],
                [0.3548, 0.4675, 0.7123, 0.5700],
                [0.8737, 0.5115, 0.2106, 0.5849]])
    
    b
    # 输出:
        tensor([0.3309, 0.3712, 0.0982, 0.2331])
        
    # 相加
    # b会被广播
    a + b
    # 输出:
        tensor([[0.9541, 0.8778, 0.9461, 0.8380],
                [0.6857, 0.8387, 0.8105, 0.8030],
                [1.2046, 0.8827, 0.3088, 0.8179]])   
    # 等价于上面相加
    torch.add(a, b)
    # 输出:
        tensor([[0.9541, 0.8778, 0.9461, 0.8380],
                [0.6857, 0.8387, 0.8105, 0.8030],
                [1.2046, 0.8827, 0.3088, 0.8179]])  
    
    # 比较两个是否相等
    torch.all(torch.eq(a + b, torch.add(a, b)))
    # 输出:
        tensor(True)    
    

    2、矩阵相乘

    • torch.mm(a, b) # 此方法只适用于2维

    • torch.matmul(a, b)

    • a @ b = torch.matmul(a, b) # 推荐使用此方法

    • 用处:

      1. 降维:比如,[4, 784] @ [784, 512] = [4, 512]
      2. 大于2d的数据相乘:最后2个维度的数据相乘:[4, 3, 28, 64] @ [4, 3, 64, 32] = [4, 3, 28, 32]

        前提是:除了最后两个维度满足相乘条件以外,其他维度要满足广播条件,比如此处的前面两个维度只能是[4, 3]和[4, 1]
    a = torch.full((2, 2), 3)
    a
    # 输出
        tensor([[3., 3.],
                [3., 3.]])
    
    b = torch.ones(2, 2)
    b
    # 输出
        tensor([[1., 1.],
                [1., 1.]])
        
    torch.mm(a, b)
    # 输出
        tensor([[6., 6.],
                [6., 6.]])
    
    torch.matmul(a, b)
    # 输出
        tensor([[6., 6.],
                [6., 6.]])
        
    a @ b
    # 输出
        tensor([[6., 6.],
                [6., 6.]])    
    
    
    

    3、幂次计算

    • pow, sqrt, rsqrt
    a = torch.full([2, 2], 3)
    a
    # 输出
        tensor([[3., 3.],
                [3., 3.]])
        
    a.pow(2)
    # 输出
        tensor([[9., 9.],
                [9., 9.]])    
        
    aa = a ** 2
    aa
    # 输出
        tensor([[9., 9.],
                [9., 9.]]) 
        
    # 平方根
    aa.sqrt()
    # 输出
        tensor([[3., 3.],
                [3., 3.]])
    # 平方根    
    aa ** (0.5)
    # 输出
        tensor([[3., 3.],
                [3., 3.]])    
    # 平方根    
    aa.pow(0.5)
    # 输出
        tensor([[3., 3.],
                [3., 3.]])    
        
    # 平方根的倒数
    aa.rsqrt()
    # 输出
        tensor([[0.3333, 0.3333],
                [0.3333, 0.3333]])        
    
    tensor([[3., 3.],
            [3., 3.]])
    

    4、自然底数与对数

    a = torch.ones(2, 2)
    a
    # 输出
        tensor([[1., 1.],
                [1., 1.]])
        
    # 自认底数e
    torch.exp(a)
    # 输出
        tensor([[2.7183, 2.7183],
                [2.7183, 2.7183]])
    
    # 对数
    # 默认底数是e
    # 可以更换为Log2、log10
    torch.log(a)
    # 输出
    tensor([[0., 0.],
            [0., 0.]])    
    

    5、近似值

    • a.floor() # 向下取整:floor,地板
    • a.ceil() # 向上取整:ceil,天花板
    • a.trunc() # 保留整数部分:truncate,截断
    • a.frac() # 保留小数部分:fraction,小数
    • a.round() # 四舍五入:round,大约

    6、限幅

    • a.max() # 最大值
    • a.min() # 最小值
    • a.median() # 中位数
    • a.clamp(10) # 将最小值限定为10
      • a.clamp(0, 10) # 将数据限定在[0, 10],两边都是闭区间
  • 相关阅读:
    delphi 不规则窗体与桌面宠物
    delphi窗体透明但上面的控件不透明怎么实现
    IIS错误:在唯一密钥属性“fileExtension”设置为“.json”时,无法添加类型为“mimeMap”的重复集合项
    putty连接centos慢
    centos systemd占用大量内存
    laravel 163发送邮件设置及常见错误
    laravel 163发送邮件
    laravel npm run dev 错误 npm run dev error [npm ERR! code ELIFECYCLE]
    linux 如何指定nologin用户执行命令
    laravel Method IlluminateValidationValidator::validateReuqired does not exist.
  • 原文地址:https://www.cnblogs.com/jaysonteng/p/13040596.html
Copyright © 2011-2022 走看看