zoukankan      html  css  js  c++  java
  • pytorch基本运算:加减乘除、对数幂次等

    1、加减乘除

    • a + b = torch.add(a, b)
    • a - b = torch.sub(a, b)
    • a * b = torch.mul(a, b)
    • a / b = torch.div(a, b)
    import torch
    
    a = torch.rand(3, 4)
    b = torch.rand(4)
    a
    # 输出:
        tensor([[0.6232, 0.5066, 0.8479, 0.6049],
                [0.3548, 0.4675, 0.7123, 0.5700],
                [0.8737, 0.5115, 0.2106, 0.5849]])
    
    b
    # 输出:
        tensor([0.3309, 0.3712, 0.0982, 0.2331])
        
    # 相加
    # b会被广播
    a + b
    # 输出:
        tensor([[0.9541, 0.8778, 0.9461, 0.8380],
                [0.6857, 0.8387, 0.8105, 0.8030],
                [1.2046, 0.8827, 0.3088, 0.8179]])   
    # 等价于上面相加
    torch.add(a, b)
    # 输出:
        tensor([[0.9541, 0.8778, 0.9461, 0.8380],
                [0.6857, 0.8387, 0.8105, 0.8030],
                [1.2046, 0.8827, 0.3088, 0.8179]])  
    
    # 比较两个是否相等
    torch.all(torch.eq(a + b, torch.add(a, b)))
    # 输出:
        tensor(True)    
    

    2、矩阵相乘

    • torch.mm(a, b) # 此方法只适用于2维

    • torch.matmul(a, b)

    • a @ b = torch.matmul(a, b) # 推荐使用此方法

    • 用处:

      1. 降维:比如,[4, 784] @ [784, 512] = [4, 512]
      2. 大于2d的数据相乘:最后2个维度的数据相乘:[4, 3, 28, 64] @ [4, 3, 64, 32] = [4, 3, 28, 32]

        前提是:除了最后两个维度满足相乘条件以外,其他维度要满足广播条件,比如此处的前面两个维度只能是[4, 3]和[4, 1]
    a = torch.full((2, 2), 3)
    a
    # 输出
        tensor([[3., 3.],
                [3., 3.]])
    
    b = torch.ones(2, 2)
    b
    # 输出
        tensor([[1., 1.],
                [1., 1.]])
        
    torch.mm(a, b)
    # 输出
        tensor([[6., 6.],
                [6., 6.]])
    
    torch.matmul(a, b)
    # 输出
        tensor([[6., 6.],
                [6., 6.]])
        
    a @ b
    # 输出
        tensor([[6., 6.],
                [6., 6.]])    
    
    
    

    3、幂次计算

    • pow, sqrt, rsqrt
    a = torch.full([2, 2], 3)
    a
    # 输出
        tensor([[3., 3.],
                [3., 3.]])
        
    a.pow(2)
    # 输出
        tensor([[9., 9.],
                [9., 9.]])    
        
    aa = a ** 2
    aa
    # 输出
        tensor([[9., 9.],
                [9., 9.]]) 
        
    # 平方根
    aa.sqrt()
    # 输出
        tensor([[3., 3.],
                [3., 3.]])
    # 平方根    
    aa ** (0.5)
    # 输出
        tensor([[3., 3.],
                [3., 3.]])    
    # 平方根    
    aa.pow(0.5)
    # 输出
        tensor([[3., 3.],
                [3., 3.]])    
        
    # 平方根的倒数
    aa.rsqrt()
    # 输出
        tensor([[0.3333, 0.3333],
                [0.3333, 0.3333]])        
    
    tensor([[3., 3.],
            [3., 3.]])
    

    4、自然底数与对数

    a = torch.ones(2, 2)
    a
    # 输出
        tensor([[1., 1.],
                [1., 1.]])
        
    # 自认底数e
    torch.exp(a)
    # 输出
        tensor([[2.7183, 2.7183],
                [2.7183, 2.7183]])
    
    # 对数
    # 默认底数是e
    # 可以更换为Log2、log10
    torch.log(a)
    # 输出
    tensor([[0., 0.],
            [0., 0.]])    
    

    5、近似值

    • a.floor() # 向下取整:floor,地板
    • a.ceil() # 向上取整:ceil,天花板
    • a.trunc() # 保留整数部分:truncate,截断
    • a.frac() # 保留小数部分:fraction,小数
    • a.round() # 四舍五入:round,大约

    6、限幅

    • a.max() # 最大值
    • a.min() # 最小值
    • a.median() # 中位数
    • a.clamp(10) # 将最小值限定为10
      • a.clamp(0, 10) # 将数据限定在[0, 10],两边都是闭区间
  • 相关阅读:
    在IE和Firfox获取keycode
    using global variable in android extends application
    using Broadcast Receivers to listen outgoing call in android note
    help me!virtual keyboard issue
    using iscroll.js and iscroll jquery plugin in android webview to scroll div and ajax load data.
    javascript:jquery.history.js使用方法
    【CSS核心概念】弹性盒子布局
    【Canvas学习笔记】基础篇(二)
    【JS核心概念】数据类型以及判断方法
    【问题记录】ElementUI上传组件使用beforeupload钩子校验失败时的问题处理
  • 原文地址:https://www.cnblogs.com/jaysonteng/p/13040596.html
Copyright © 2011-2022 走看看