zoukankan      html  css  js  c++  java
  • PyTorch【5】-Tensor 运算

    Tensor API 较多,所以把 运算 单独列出来,方便查看

    本教程环境 pytorch 1.3以上

    乘法

    t.mul(input, other, out=None):矩阵乘以一个数

    t.matmul(mat, mat, out=None):矩阵相乘

    t.mm(mat, mat, out=None):基本上等同于 matmul

    a=torch.randn(2,3)
    b=torch.randn(3,2)
    ### 等价操作
    print(torch.mm(a,b))        # mat x mat
    print(torch.matmul(a,b))    # mat x mat
    ### 等价操作
    print(torch.mul(a,3))       # mat 乘以 一个数
    print(a * 3)

    注意,乘法可以直接作用于单个数字

    乘法需要符合 向量乘法 的规则,即尺寸匹配

    a=torch.randn(2,3)
    c = torch.randn(2, 3)
    # print(torch.matmul(a, c))   # 尺寸不符合向量乘法,(2,3)x(2,3)
    print(torch.matmul(a, c.t())) # t() 转置,正确 (2,3)x(3,2)

    加法

    加法有 3 种方式:+,add,add_

    import torch as t
    y = t.rand(2, 3)        ### 使用[0,1]均匀分布构建矩阵
    z = t.ones(2, 3)        ### 2x3 的全 1 矩阵
    
    #### 3 中加法操作等价
    print(y + z)            ### 加法1
    t.add(y, z)             ### 加法2
    ### 加法的第三种写法
    result = t.Tensor(2, 3) ### 预先分配空间
    t.add(y, z, out=result) ### 指定加法结果的输出目标
    print(result)

    add_ 与 add 的区别在于,add 不会改变原来的 tensor,而 add_会改变原来的 tensor;

    在 pytorch 中,方法后面加  _ 都会改变原来的对象,相当于 in-place 的作用

    print(y)
    # tensor([[0.4083, 0.3017, 0.9511],
    #         [0.4642, 0.5981, 0.1866]])
    y.add(z)
    print(y)                ### y 不变
    # tensor([[0.4083, 0.3017, 0.9511],
    #         [0.4642, 0.5981, 0.1866]])
    y.add_(z)
    print(y)                ### y 变了,相当于 inplace
    # tensor([[1.4083, 1.3017, 1.9511],
    #         [1.4642, 1.5981, 1.1866]])

    可以作用于单个数字或者 尺寸为 (1,1) 的 Tensor

    a = t.ones(3, 3)
    print(a + 1)        ### 可以直接作用于单个数字
    
    b = t.ones(1, 1)
    print(a + b)
    
    c = t.ones(2, 1)
    # print(a + c)        ### 报错,如果尺寸不匹配,c 的尺寸只能是 (1, 1)

    减法 

    和加法一样,三种:-、sub、sub_

    a = t.randn(2, 1)
    b = t.randn(2, 1)
    print(a)
    ### 等价操作
    print(a - b)
    print(t.sub(a, b))
    print(a)        ### sub 后 a 没有变化
    
    a.sub_(b)
    print(a)        ### sub_ 后 a 也变了
    
    c = 1
    print(a - c)    ### 直接作用于单个数字

    其他运算

    t.div(input, other, out=None):除法

    t.pow(input, other, out=None):指数

    t.sqrt(input, out=None):开方

    t.round(input, out=None):四舍五入到整数

    t.abs(input, out=None):绝对值

    t.ceil(input, out=None):向上取整

    t.clamp(input, min, max, out=None):把 input 规范在 min 到 max 之间,超出用 min 和 max 代替,可理解为削尖函数

    t.argmax(input, dim=None, keepdim=False):返回指定维度最大值的索引

    t.sigmoid(input, out=None)

    t.tanh(input, out=None)

    参考资料:

  • 相关阅读:
    Python基础篇(七)
    RMI基础
    Python基础篇(五)
    装饰模式
    一些linux知识和http知识
    mysql统计一个库里面的表的总数
    关于phpmailer邮件发送
    Jenkins是什么?
    Android开发——JVM、Dalvik以及ART的区别【转帖】
    好记性不如烂笔头--linux学习笔记9练手写个shell脚本
  • 原文地址:https://www.cnblogs.com/yanshw/p/12206849.html
Copyright © 2011-2022 走看看