zoukankan      html  css  js  c++  java
  • 向量,矩阵和张量的导数

    https://zhuanlan.zhihu.com/p/142668996

    前段时间看过一些矩阵求导的教程,在看过的资料中,尤其喜欢斯坦福大学CS231n卷积神经网络课程中提到的Erik这篇文章。循着他的思路,可以逐步将复杂的求导过程简化、再简化,直到发现其中有规律的部分。话不多说,一起来看看吧。

     

    撰文 | Erik Learned-Miller

    翻译 | 写代码的橘子

    来源 | 橘子AI笔记(ID:datawitch)

     

    本文旨在帮助您学习向量、矩阵和高阶张量(三维或三维以上的数组)的求导方法,以及如何求对向量、矩阵和高阶张量的导数。

    01. 简化,简化,再简化

    在求关于数组的导数时,大部分困惑都源自于我们想要一次同时做好几件事。这“几件事”包括同时对多个元素求导、在求和符号下求导以及应用链式法则。至少在我们积累丰富的经验之前,想要同时做这么多件事情是很容易犯错的。

    1.1 写出矩阵中单个元素的表达式

    为了简化给定的计算,有一种方法是:写出输出中单个标量元素的表达式,这个表达式只包含标量变量。一旦写出了输出中单个标量元素与其他标量值的表达式,就可以使用标量的微积分求导方法,这比同时进行矩阵的求和、求导要容易得多。

    例子 假设我们有一个长度为C的列向量 [公式] ,它是由 [公式] 行 [公式] 列的矩阵 [公式] 与长度为 [公式] 的向量 [公式] 计算得到的:

    式(1)

    假设我们想求 [公式] 对 [公式] 的导数。完整的求导过程需要计算 [公式] 中的每一个元素对 [公式] 中的每一个元素的(偏)导数,在这种情况下,我们会算出 [公式] 个元素,因为 [公式] 中有 [公式] 个元素而 [公式] 中有 [公式] 个元素。

    让我们先从计算其中一个元素开始,比如, [公式] 中的第3个元素对 [公式] 中的第7个元素求导。也就是说,我们要计算

    也就是一个标量对另一个标量求导。

    在求导之前,我们要先写出 [公式] 的表达式。根据矩阵-向量乘法的定义,矩阵 [公式] 的第3行与向量 [公式] 的点积就是 [公式] 的值。

    式(2)

    此时,我们已经将原始矩阵方程式(1)简化为了一个标量方程,从而更容易计算所需的导数。

    1.2 去掉求和符号

    虽然我们可以尝试直接求式(2)的导数,但包含求和符号( [公式] )或连乘符号( [公式] )的表达式在求导时很容易出错。为了确保万无一失,在刚开始的时候最好去掉求和符号,把各项相加的表达式写出来。我们可以写出以下表达式,下标由“1”开始

    当然,这个表达式中包括了含有 [公式] 的项,这一项正是我们求导需要的项。现在不难看出,在求 [公式] 对 [公式] 的偏导数时,我们只关心这个表达式中的一项,[公式] 。由于其他项都不包括 [公式] ,他们对 [公式] 的导数都是0。由此,我们写出

    式(3)-式(6)

    通过把关注点放在y中的一个元素对x中的一个元素的求导过程,我们尽可能地简化了计算。以后当你在矩阵求导计算中产生困惑时,也可以试着将问题简化到这个最基本的程度,这样便于看清哪里出了问题。

    1.2.1 完成求导:雅可比矩阵

    别忘了,我们的终极目标是计算 [公式] 中每个元素对 [公式] 中每个元素的导数,这些导数总共有 [公式] 个。以下矩阵可以表示所有这些导数:

    在这种特殊情况下,它被称为雅可比矩阵(Jacobian maxtirx),但这个术语对理解我们的目的而言并不那么重要。

    注意,对于公式

    [公式] 对 [公式] 的偏导数可以简单地用 [公式] 来表示。如果挨个儿检查整个矩阵中的所有元素,就不难发现,对所有的 [公式] 和 [公式] 来说,都有

    也就是说,偏导数的矩阵可以表示为

    现在可以看出,这个矩阵当然就是矩阵 [公式] 本身。

    因此,推导了这么半天,我们终于能得出,对

    求 [公式] 对 [公式] 的导数相当于

    2. 如果是行向量该怎么算

    在使用不同的神经网络库时,留意权重矩阵、数据矩阵等矩阵的具体表达形式是非常重要的。例如,如果一个数据矩阵 [公式] 包含许多不同的向量,那么,在这个矩阵中,是一个行向量表示数据集中的一个样本,还是一个列向量表示一个样本?

    在第一部分的例子中,我们计算的向量 [公式] 是一个列向量。然而,当 [公式] 是行向量的时候你也得明白该怎么算。

    2.1 第二个例子

    假设 [公式] 是含有 [公式] 个元素的行向量,它是由含有 [公式] 个元素的行向量 [公式] 与 [公式] 行 [公式] 列的矩阵 [公式] 计算得到的:

    虽然 [公式] 和 [公式] 中的元素数量都和之前一样,但矩阵 [公式] 的形状相当于我们在第一个例子中使用的矩阵 [公式] 的转置(transpose)。尤其是因为我们现在是矩阵 [公式] 左乘 [公式] ,而不是之前的右乘,现在的矩阵 [公式] 必须是第一个例子中矩阵 [公式] 的转置。

    在这个例子中,写出 [公式] 的表达式

    会得到

    注意这个例子中的元素序号与第一个例子中相反。如果写出完整的雅可比矩阵,我们仍然可以得出

    式(7)

    3. 超过二维的情形该怎么算

    现在假设一个与前两部分密切相关的情形,如下式

    在这个情况下, [公式] 沿一个坐标轴变化,而 [公式] 沿两个坐标轴变化。因此,整个导数自然会是一个三维数组。在这里,我们避免使用“三维矩阵”这样的术语,因为尚不清楚矩阵乘法和其他矩阵运算在三维数组中是如何定义的。

    在处理三维数组的时候,尝试去找出展示它们的方法可能会带来不必要的麻烦。相反,我们应该简单地用表达式写出结果,用这些表达式可以计算出所需三维数组中的任何元素。

    让我们继续以标量导数的计算开始,比如y中的一个元素 [公式] 和 [公式] 中的一个元素 [公式] 。我们先用其他标量写出 [公式] 的表达式,这个表达式还要体现出 [公式] 在其计算中所起的作用。

    然而,我们发现 [公式] 在 [公式] 的计算中没有起到任何作用,因为

    式(8)

    也就是说

    不过, [公式] 对 [公式] 中第3列元素求导的结果一定是非零的。例如 [公式] 对 [公式] 的偏导数为

    式(9)

    其实仔细看式(8)就很容易发现这一点。

    一般情况下,当 [公式] 中元素的下标等于 [公式] 中元素的第二个下标时,这个偏导数就是非零的,反之则为零。我们由此写出:

    除此以外,三维数组中的其他元素都是0。如果用 [公式] 表示 [公式] 对 [公式] 求导得出的三维数组

    其中

    但是 [公式] 中的其他项都为0。

    最终,如果我们定义一个新的二维数组 [公式]

    就可以看出,我们需要的所有关于 [公式] 的信息实际上都可以用 [公式] 来储存,也就是说, [公式] 的非零部分其实是二维的,而不是三维的。

    以紧凑的形式表示导数数组对于神经网络的高效实现而言至关重要。

    4. 有多条数据该怎么算

    前面的例子已经是很好的求导练习了,但如果需要用到多条数据,也就是多个向量 [公式] 堆叠在一起构成矩阵 [公式] 时,又该如何计算呢?我们假设每个单独的 [公式] 都是一个长度为 [公式] 的行向量,矩阵 [公式] 是一个 [公式] 行 [公式] 列的二维数组。而矩阵 [公式] ,和之前的例子一样,是一个 [公式] 行 [公式] 列的矩阵。 [公式] 的定义如下

    它是一个 [公式] 行 [公式] 列的矩阵。因此, [公式] 的每一行将给出一个与输入 [公式] 的相应行相关的行向量。

    按照我们写出给定元素表达式的方法,可以写出

    我们马上就能从这个式子中看出,对于偏导数

    只有 [公式] 的时候计算结果才不为零。也就是说,因为 [公式] 中的每一个元素都只对 [公式] 中相应的那一行求导, [公式] 与 [公式] 的不同行之间的偏导数都为0。

    我们可以进一步发现

    式(10)

    完全不依赖于我们比较的是 [公式] 和 [公式] 的哪一行。

    事实上,矩阵 [公式] 完整包含了所有的偏导数——我们只需要根据式(10)和下标来找到我们想要的特定偏导数。

    如果用 [公式] 表示 [公式] 中的第 [公式] 行,用 [公式] 表示 [公式] 中的第 [公式] 行,可以发现

    正是对之前式(7)的一个简单的普遍化形式。

    5. 向量和矩阵中的链式法则

    我们已经通过几个例子学会了一些基本形式的计算,现在通过链式法则把这些例子结合在一起。再次假设 [公式] 和 [公式] 是两个列向量,让我们从下式开始

    尝试计算 [公式] 对 [公式] 的导数。我们可以简单地观察到两个矩阵 [公式] 和 [公式] 的乘积就是另一个矩阵 [公式] ,因此可以写出

    然而,我们想通过链式法则来定义中间结果,以观察在非标量求导过程中是如何应用链式法则的。

    我们把中间结果定义为

    于是有

    然后我们可以运用链式法则写出

    为了确保我们确切地知道该式的含义,再次采用每次分析一个元素的老办法,从 [公式] 中的一个元素和 [公式] 中的一个元素开始:

    右边的乘积该怎么解释呢?链式法则的思想是将 [公式] 对每个标量中间变量的导数与中间变量对 [公式] 的导数相乘。特别地,如果 [公式] 有 [公式] 个元素,那么可以写出

    回忆之前关于向量对向量求导的计算方法,发现

    其实是 [公式] ,而

    其实是 [公式] 。所以可以写出

    这就是用 [公式] 中的元素写出的求导表达式,至此我们得出了答案。

    综上所述,我们可以用链式法则来表示向量和矩阵的导数,只需要注意:

    • 清楚说明中间结果和表示中间结果的变量,
    • 表示出最终导数中各个元素的链式法则,
    • 对链式法则表达式中的中间结果适当求和。

     

     

    参考资料:
  • 相关阅读:
    Mixtile LOFT
    关于Linux系统清理/tmp/文件夹的原理
    在大型项目上,Python 是个烂语言吗
    Nginx 进程间通信
    蕤仁肉(内仁肉、泪仁肉)简单介绍
    TMS320F28335项目开发记录5_28335之CCS编程基础
    RBAC权限管理
    国内三大PT(Private Tracker)站分析
    Nginx特点
    java设计模式演示样例
  • 原文地址:https://www.cnblogs.com/dhcn/p/13476268.html
Copyright © 2011-2022 走看看