zoukankan      html  css  js  c++  java
  • 『PyTorch』矩阵乘法总结

    1. 二维矩阵乘法 torch.mm()

    torch.mm(mat1, mat2, out=None),其中mat1((n imes m)),mat2((m imes d)),输出out的维度是((n imes d))。

    该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。

    2. 三维带batch的矩阵乘法 torch.bmm()

    由于神经网络训练一般采用mini-batch,经常输入的时三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None),其中bmat1((b imes n imes m)),bmat2((b imes m imes d)),输出out的维度是((b imes n imes d))。

    该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。

    3. 多维矩阵乘法 torch.matmul()

    torch.matmul(input, other, out=None)支持broadcast操作,使用起来比较复杂。

    针对多维数据 matmul()乘法,我们可以认为该matmul()乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别是input((1000 imes 500 imes 99 imes 11)), other((500 imes 11 imes 99))那么我们可以认为torch.matmul(input, other, out=None)乘法首先是进行后两位矩阵乘法得到((99 imes 11) imes (11 imes 99)Rightarrow(99 imes 99)) ,然后分析两个参数的batch size分别是 (( 1000 imes 500))(500) , 可以广播成为 ((1000 imes 500)), 因此最终输出的维度是((1000 imes 500 imes 99 imes 99))。

    4. 矩阵逐元素(Element-wise)乘法 torch.mul()

    torch.mul(mat1, other, out=None),其中other乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast的即可

    5. 两个运算符 @ 和 *

    • @:矩阵乘法,自动执行适合的矩阵乘法函数
    • *element-wise乘法
  • 相关阅读:
    Linux服务器使用tar加密压缩文件
    ssh-copy-id使用非默认22端口
    Nginx日志分割脚本
    MySQL的yum源
    vSphere Client开启虚拟机提示:出现了常规系统错误: 由于目标计算机积极拒绝,无法连接。
    ESXi主机遗忘密码重置密码
    扩容swap交换分区空间
    ESXi上的固态硬盘识别为非SSD
    VMware Vcenter Server 6.0忘记密码
    Centos6与Centos7区别
  • 原文地址:https://www.cnblogs.com/ice-coder/p/12951435.html
Copyright © 2011-2022 走看看