zoukankan      html  css  js  c++  java
  • 【小白学PyTorch】10 pytorch常见运算详解

    参考目录:

    这一课主要是讲解PyTorch中的一些运算,加减乘除这些,当然还有矩阵的乘法这些。这一课内容不多,作为一个知识储备。在后续的内容中,有用PyTorch来获取EfficientNet预训练模型以及一个猫狗给分类的实战任务教学。

    加减乘除就不多说了,+-*/

    1 矩阵与标量

    这个是矩阵(张量)每一个元素与标量进行操作。

    import torch
    a = torch.tensor([1,2])
    print(a+1)
    >>> tensor([2, 3])
    

    2 哈达玛积

    这个就是两个相同尺寸的张量相乘,然后对应元素的相乘就是这个哈达玛积,也成为element wise。

    a = torch.tensor([1,2])
    b = torch.tensor([2,3])
    print(a*b)
    print(torch.mul(a,b))
    >>> tensor([2, 6])
    >>> tensor([2, 6])
    

    这个torch.mul()*是等价的。

    当然,除法也是类似的:

    a = torch.tensor([1.,2.])
    b = torch.tensor([2.,3.])
    print(a/b)
    print(torch.div(a/b))
    >>> tensor([0.5000, 0.6667])
    >>> tensor([0.5000, 0.6667])
    

    我们可以发现的torch.div()其实就是/, 类似的:torch.add就是+,torch.sub()就是-,不过符号的运算更简单常用。

    3 矩阵乘法

    如果我们想实现线性代数中的矩阵相乘怎么办呢?

    这样的操作有三个写法:

    • torch.mm()
    • torch.matmul()
    • @,这个需要记忆,不然遇到这个可能会挺蒙蔽的
    a = torch.tensor([1.,2.])
    b = torch.tensor([2.,3.]).view(1,2)
    print(torch.mm(a, b))
    print(torch.matmul(a, b))
    print(a @ b)
    

    输出结果:

    tensor([[2., 3.],
            [4., 6.]])
    tensor([[2., 3.],
            [4., 6.]])
    tensor([[2., 3.],
            [4., 6.]])
    

    这是对二维矩阵而言的,假如参与运算的是一个多维张量,那么只有torch.matmul()可以使用。等等,多维张量怎么进行矩阵的惩罚?在多维张量中,参与矩阵运算的其实只有后两个维度,前面的维度其实就像是索引一样,举个例子:

    a = torch.rand((1,2,64,32))
    b = torch.rand((1,2,32,64))
    print(torch.matmul(a, b).shape)
    >>> torch.Size([1, 2, 64, 64])
    

    可以看到,其实矩阵乘法的时候,看后两个维度:(64 imes 32) 乘上 (32 imes 64),得到一个(64 imes 64)的矩阵。前面的维度要求相同,像是索引一样,决定哪两个(64 imes 32)(32 imes 64)相乘。

    小提示:

    a = torch.rand((3,2,64,32))
    b = torch.rand((1,2,32,64))
    print(torch.matmul(a, b).shape)
    >>> torch.Size([3, 2, 64, 64])
    

    这样也是可以相乘的,因为这里涉及一个自动传播Broadcasting机制,这个在后面会讲,这里就知道,如果这种情况下,会把b的第一维度复制3次 ,然后变成和a一样的尺寸,进行矩阵相乘。

    4 幂与开方

    print('幂运算')
    a = torch.tensor([1.,2.])
    b = torch.tensor([2.,3.])
    c1 = a ** b
    c2 = torch.pow(a, b)
    print(c1,c2)
    >>> tensor([1., 8.]) tensor([1., 8.])
    

    和上面一样,不多说了。
    开方运算可以用torch.sqrt(),当然也可以用a**(0.5)。

    5 对数运算

    在上学的时候,我们知道ln是以e为底的,但是在pytorch中,并不是这样

    pytorch中log是以e自然数为底数的,然后log2和log10才是以2和10为底数的运算。

    import numpy as np
    print('对数运算')
    a = torch.tensor([2,10,np.e])
    print(torch.log(a))
    print(torch.log2(a))
    print(torch.log10(a))
    >>> tensor([0.6931, 2.3026, 1.0000])
    >>> tensor([1.0000, 3.3219, 1.4427])
    >>> tensor([0.3010, 1.0000, 0.4343]) 
    

    6 近似值运算

    • .ceil() 向上取整
    • .floor()向下取整
    • .trunc()取整数
    • .frac()取小数
    • .round()四舍五入
    a = torch.tensor(1.2345)
    print(a.ceil())
    >>>tensor(2.)
    print(a.floor())
    >>> tensor(1.)
    print(a.trunc())
    >>> tensor(1.)
    print(a.frac())
    >>> tensor(0.2345)
    print(a.round())
    >>> tensor(1.)
    

    7 剪裁运算

    这个是让一个数,限制在你自己设置的一个范围内[min,max],小于min的话就被设置为min,大于max的话就被设置为max。这个操作在一些对抗生成网络中,好像是WGAN-GP,通过强行限制模型的参数的值。

    a = torch.rand(5)
    print(a)
    print(a.clamp(0.3,0.7))
    

    输出为:

    tensor([0.5271, 0.6924, 0.9919, 0.0095, 0.0340])
    tensor([0.5271, 0.6924, 0.7000, 0.3000, 0.3000])
    
    人不可傲慢。
  • 相关阅读:
    关于For循环的性能
    CLR读书笔记
    轻量级自动化测试框架介绍
    loadrunner中如何将MD5加密的值转换为大写
    LoadRunner 中实现MD5加密
    新安装的soapui启动时报错及解决方法
    单元测试之驱动模块和桩模块的作用和区别
    接口自动化(Python)-利用正则表达式从返回的HTML文本中截取自己想要的值
    LoadRunner性能测试-loadrunner事务
    LoadRunner性能测试-loadrunner工具破解
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/13669628.html
Copyright © 2011-2022 走看看