zoukankan      html  css  js  c++  java
  • tensor的数学运算

    tensor的数学运算


    & n b s p ;

    基础四则运算

    a + b ——torch.add()

    a - b ——torch.sub()

    a * b ——torch.mul()

    a / b ——torch.div()

    效果一致

    & n b s p ;

    矩阵相乘

    torch.mm:只适用于2D的矩阵相乘

    torch.matmul推荐使用

    @ (a@b)与matmul效果相同

    a = torch.ones(2,2)
    b = torch.ones(2,2)
    print(torch.matmul(a,b))
    
    #tensor([[2., 2.],
    #        [2., 2.]])
    

    实例:

    a = torch.rand(4,784)
    x = torch.rand(4,784)
    w = torch.rand(512,784)
    (x @  w.t()).shape
    #[4,512]
    

    X有4张照片,每张照片都打平了,我们现在希望进行一个降维的过程:【4,784】——>【4,512】

    所以我们在中间要构建一个w:【784,512】,通过x@w来实现降维

    但是这里的w是按照pytorch的写法:【channel_out(出去的维度),channel_in(进入的维度)】,所以要进行一个矩阵的转置。

    & n b s p ;

    二维以上的矩阵转置

    保持前两维不变,进行后两维的运算

    例子:

    a = torch.rand(4,3,28,64)
    b = torch.rand(4,3,64,32)
    torch.matmul(a,b).shape
    #[4,3,28,32]
    

    运用broadcst规则:

    a = torch.rand(4,3,28,64)
    b = torch.rand(4,1,64,32)
    torch.matmul(a,b).shape
    #[4,3,28,32]
    

    & n b s p ;

    power函数以及sqrt函数

    进行矩阵的平方运算以及开方运算

    a = torch.full([2,2],3)
    a.pow(2)
    #[[9,9],
    #[9,9]]
    

    aa = a.pow(2)——aa = a**2

    aa = a.sqrt(2)——aa = a**(0.5)

    & n b s p ;

    exp log函数

    exp函数:求自然指数

    log函数:求log(默认以e为底)

    log2:以2为底

    log10:以10为底

    a = torch.exp(torch.ones(2,2))
    #[[2.7183,2.7183],
    #[2.7183,2.7183]]
    
    torch.log(a)
    #[[1,1],
    #[1,1]]
    

    & n b s p ;

    近似值

    floor()向下取整

    ceil()向上取整

    trunc()取整数部分

    frac()取小数部分

    round()四舍五入

    例子:

    a = torch.tensor(3.14)
    print(a.trunc())
    print(a.frac())
    #tensor(3.)
    #tensor(0.1400)
    

    & n b s p ;

    clamp函数

    限制矩阵中的最大值与最小值

    例子:将矩阵中所有小于10的都变为10

    grad = torch.rand(2,3)*15
    print(grad)
    
    tensor([[13.3244, 11.6749, 14.0967],
            [13.3109,  7.9303,  8.3319]])
    
    temp = grad.clamp(10)
    print(temp)
    
    tensor([[13.3244, 11.6749, 14.0967],
            [13.3109, 10.0000, 10.0000]])
    

    grad.clamp(0,10)是将矩阵中的元素控制在(0,10)

    用处:在进行训练的时候通过这种方式来控制梯度的大小来防止梯度爆炸以及梯度消失

  • 相关阅读:
    插入数据失败提示: Setting autocommit to false on JDBC Connection 自动提交失败
    MyBatis XML配置properties
    mybatis 测试输出SQL语句到控制台配置
    原创:mysql5 还原至mysql 8.0.11数据库链接配置提示错误(修改内容有三处
    idea 快捷键汇总
    maven依赖配置和依赖范围
    pom.xml 配置 收藏
    单词的提取
    UVA10815 安迪的第一个字典 Andy's First Dictionary
    UVA11054 Gergovia的酒交易 Wine trading in Gergovia
  • 原文地址:https://www.cnblogs.com/Jason66661010/p/13600400.html
Copyright © 2011-2022 走看看