zoukankan      html  css  js  c++  java
  • $infty$-former: Infinite Memory Transformer

    Martins P., Marinho Z. and Martins A. (infty)-former: Infinite Memory Transformer. arXiv preprint arXiv:2109.00301, 2021.

    在transformer中引入一种长期记忆机制.

    主要内容

    假设(X in mathbb{R}^{L imes d}), 即每一行(x_i)代表一个token对应的特征.
    Attention需要进行如下的步骤:

    [Q = XW^Q, K = X W^K, V = XW^V, \ Z = mathrm{softmax}(frac{QK^T}{sqrt{d}})V. ]

    为了符号简易起见, 我们不考虑multi-head的情形, 下面的思想可以直接应用之.

    我们知道, 可以通过径向基函数来逼近任意的连续函数:

    [sum_{k} b_k psi_k (t) ightarrow f(t). ]

    现在, 我们令(t_i = frac{i}{L}), 即对(L)个tokens冠以时序, (X)的每一列都可以看成一个特殊的(f_j(t))的位于(t_i, i=0,1,cdots, L-1)处的值.
    给定(N)个基函数(psi_k (t), k=0,1,cdots, N-1), 我们要通过求解系数(m{b}_j = [b_{j0}, b_{j1},cdots b_{j,N-1}]^T)来逼近(f_j)((X)的第(j)列).
    (Psi in mathbb{R}^{N imes L}, Psi_{ki}=psi_{k}(t_i)), (B in mathbb{R}^{d imes N}, B_{jk} = b_{jk}).
    作者通过岭回归来求解系数(b):

    [B = arg min_{B} |B Psi - X^T|_F^2 + lambda |B|_F^2, ]

    其显示表达式为:

    [B = X^TPsi^T(PsiPsi^T + lambda I)^{-1}. ]

    [X^T approx BPsi ightarrow x_i approx B psi (t_i). ]

    现在我们用( ilde{X} := Psi^T B^T)来代替(X), 则

    [K = ilde{X} W^K = Psi^TB^TW^K, ilde{V} = ilde{X}W^V = Psi^TB^TW^V. ]

    注意, 我们并不对(Q)进行替换, 因为这个只是用作长期的记录用, Q每次重新计算.
    对于每个(q_i), 我们构建一个其关于(t)的密度函数(p_i(t)), 文中假设其满足高斯分布:

    [mathcal{N}(t; mu_i; sigma_i^2). ]

    (mu_i, sigma_i^2)分别通过如下估计:

    [mu_i = mathrm{sigmoid} (w_{mu}^T K q_i) =mathrm{sigmoid} (w_{mu}^T B^TW^K q_i), \ sigma^2_i = mathrm{softplus} (w_{sigma}^T K q_i) =mathrm{softplus} (w_{sigma}^T B^TW^K q_i). \ ]

    注意最后令(w^TPsi^T = w^T)既然(Psi)是事先确定的.
    我们知道

    [mathrm{softmax}(frac{Kq_i}{sqrt{d}}) ]

    实际上求解的是一个离散化的(p_i(t)), 即(q_i)(k_j)的相合程度, 而

    [mathrm{softmax}(frac{Kq_i}{sqrt{d}})^TV ]

    实际上就是求解期望

    [mathbb{E}_{p_i}[v(t)]. ]

    现在我们近似了一个连续的(p_i(t)), 也可以通过这种方式得到最后的(z_i):

    [mathbb{E}_{p_i}[v(t)] =mathbb{E}_{p_i}[psi^T(t)B^TW^V] =mathbb{E}_{p_i}[psi^T(t)]B^TW^V. ]

    当我们取(psi)为高斯径向基函数的时候, 上述是由显示解的.

    现在来剖析一下, 好在哪里?
    原本的(K)(L imes d)的, 现在由于我们只需要计算(B^TW), 故实际上只有(N imes d), 我们可以选取很大的(L)但是选择较小的(N)来避免较高的复杂度.

    如何扩展?

    难不成每一次都要重新计算(B)? 倘若真的是这样就谈不上是长期记忆了.
    作者采取了一种比较巧的方法, 实际上, 现在的(Bpsi(t))可以看成是一个(d)维的向量函数.
    我们首先将其进行压缩至([0, au], au in (0, 1)):

    [Bpsi(t / au), ]

    如此一来, 整个函数的能量集中在([0, au])中, 我们可以用剩下的(( au, 1])来放置新的(X).
    我们首先从([0, au])中采样(M)个点(t_0, cdots, t_{M-1}), 并得到:

    [X_{past} = [x_0, cdots, x_{M-1}]^T in mathbb{R}^{M imes d}, x_m=psi^T(t_m/ au)B^T. ]

    加上新的(X_{new}), 我们有

    [X = [X_{past}^T, X_{new}^T]^T in mathbb{R}^{(M + L) imes d}, ]

    (X)按照上面的逻辑重新估计(B)即可更新记忆.

    关于如何采样这(M)个点, 作者提了一种sticky memories的方法, 将其与密度函数联系在一起, 便不细讲了.

    实验细节

    在看这篇论文的时候, 困扰我的就是这个径向基函数是怎么选的?
    举一个作者在Language Modeling中的例子便可:
    选取150个高斯径向基函数(mathcal{N}(t;mu, sigma^2)), 其中
    (mu)([0, 1])中均匀采样, (sigma in {0.01, 0.05}).

    还有用KL散度防止一般化就不讲了. 感觉本文有趣的点就是压缩这个地方, 还有对(Psi)的处理.

  • 相关阅读:
    解释机器学习模型的一些方法(一)——数据可视化
    机器学习模型解释工具-Lime
    Hive SQL 语法学习与实践
    LeetCode 198. 打家劫舍(House Robber)LeetCode 213. 打家劫舍 II(House Robber II)
    LeetCode 148. 排序链表(Sort List)
    LeetCode 18. 四数之和(4Sum)
    LeetCode 12. 整数转罗马数字(Integer to Roman)
    LeetCode 31. 下一个排列(Next Permutation)
    LeetCode 168. Excel表列名称(Excel Sheet Column Title)
    论FPGA建模,与面向对象编程的相似性
  • 原文地址:https://www.cnblogs.com/MTandHJ/p/15339926.html
Copyright © 2011-2022 走看看