zoukankan      html  css  js  c++  java
  • $infty$-former: Infinite Memory Transformer

    Martins P., Marinho Z. and Martins A. (infty)-former: Infinite Memory Transformer. arXiv preprint arXiv:2109.00301, 2021.

    在transformer中引入一种长期记忆机制.

    主要内容

    假设(X in mathbb{R}^{L imes d}), 即每一行(x_i)代表一个token对应的特征.
    Attention需要进行如下的步骤:

    [Q = XW^Q, K = X W^K, V = XW^V, \ Z = mathrm{softmax}(frac{QK^T}{sqrt{d}})V. ]

    为了符号简易起见, 我们不考虑multi-head的情形, 下面的思想可以直接应用之.

    我们知道, 可以通过径向基函数来逼近任意的连续函数:

    [sum_{k} b_k psi_k (t) ightarrow f(t). ]

    现在, 我们令(t_i = frac{i}{L}), 即对(L)个tokens冠以时序, (X)的每一列都可以看成一个特殊的(f_j(t))的位于(t_i, i=0,1,cdots, L-1)处的值.
    给定(N)个基函数(psi_k (t), k=0,1,cdots, N-1), 我们要通过求解系数(m{b}_j = [b_{j0}, b_{j1},cdots b_{j,N-1}]^T)来逼近(f_j)((X)的第(j)列).
    (Psi in mathbb{R}^{N imes L}, Psi_{ki}=psi_{k}(t_i)), (B in mathbb{R}^{d imes N}, B_{jk} = b_{jk}).
    作者通过岭回归来求解系数(b):

    [B = arg min_{B} |B Psi - X^T|_F^2 + lambda |B|_F^2, ]

    其显示表达式为:

    [B = X^TPsi^T(PsiPsi^T + lambda I)^{-1}. ]

    [X^T approx BPsi ightarrow x_i approx B psi (t_i). ]

    现在我们用( ilde{X} := Psi^T B^T)来代替(X), 则

    [K = ilde{X} W^K = Psi^TB^TW^K, ilde{V} = ilde{X}W^V = Psi^TB^TW^V. ]

    注意, 我们并不对(Q)进行替换, 因为这个只是用作长期的记录用, Q每次重新计算.
    对于每个(q_i), 我们构建一个其关于(t)的密度函数(p_i(t)), 文中假设其满足高斯分布:

    [mathcal{N}(t; mu_i; sigma_i^2). ]

    (mu_i, sigma_i^2)分别通过如下估计:

    [mu_i = mathrm{sigmoid} (w_{mu}^T K q_i) =mathrm{sigmoid} (w_{mu}^T B^TW^K q_i), \ sigma^2_i = mathrm{softplus} (w_{sigma}^T K q_i) =mathrm{softplus} (w_{sigma}^T B^TW^K q_i). \ ]

    注意最后令(w^TPsi^T = w^T)既然(Psi)是事先确定的.
    我们知道

    [mathrm{softmax}(frac{Kq_i}{sqrt{d}}) ]

    实际上求解的是一个离散化的(p_i(t)), 即(q_i)(k_j)的相合程度, 而

    [mathrm{softmax}(frac{Kq_i}{sqrt{d}})^TV ]

    实际上就是求解期望

    [mathbb{E}_{p_i}[v(t)]. ]

    现在我们近似了一个连续的(p_i(t)), 也可以通过这种方式得到最后的(z_i):

    [mathbb{E}_{p_i}[v(t)] =mathbb{E}_{p_i}[psi^T(t)B^TW^V] =mathbb{E}_{p_i}[psi^T(t)]B^TW^V. ]

    当我们取(psi)为高斯径向基函数的时候, 上述是由显示解的.

    现在来剖析一下, 好在哪里?
    原本的(K)(L imes d)的, 现在由于我们只需要计算(B^TW), 故实际上只有(N imes d), 我们可以选取很大的(L)但是选择较小的(N)来避免较高的复杂度.

    如何扩展?

    难不成每一次都要重新计算(B)? 倘若真的是这样就谈不上是长期记忆了.
    作者采取了一种比较巧的方法, 实际上, 现在的(Bpsi(t))可以看成是一个(d)维的向量函数.
    我们首先将其进行压缩至([0, au], au in (0, 1)):

    [Bpsi(t / au), ]

    如此一来, 整个函数的能量集中在([0, au])中, 我们可以用剩下的(( au, 1])来放置新的(X).
    我们首先从([0, au])中采样(M)个点(t_0, cdots, t_{M-1}), 并得到:

    [X_{past} = [x_0, cdots, x_{M-1}]^T in mathbb{R}^{M imes d}, x_m=psi^T(t_m/ au)B^T. ]

    加上新的(X_{new}), 我们有

    [X = [X_{past}^T, X_{new}^T]^T in mathbb{R}^{(M + L) imes d}, ]

    (X)按照上面的逻辑重新估计(B)即可更新记忆.

    关于如何采样这(M)个点, 作者提了一种sticky memories的方法, 将其与密度函数联系在一起, 便不细讲了.

    实验细节

    在看这篇论文的时候, 困扰我的就是这个径向基函数是怎么选的?
    举一个作者在Language Modeling中的例子便可:
    选取150个高斯径向基函数(mathcal{N}(t;mu, sigma^2)), 其中
    (mu)([0, 1])中均匀采样, (sigma in {0.01, 0.05}).

    还有用KL散度防止一般化就不讲了. 感觉本文有趣的点就是压缩这个地方, 还有对(Psi)的处理.

  • 相关阅读:
    C# 插件构架实战(Jack H Hansen )
    .Net 中的反射(动态创建类型实例) Part.4
    css3新添加属性>calc()
    详解IIS Express的详细配置、使用和注意事项
    SpringBoot 整合 Shiro 实现登录拦截
    java MD5 加密
    MyBatis xml foreach循环语句
    java 考试系统 在线学习 视频直播 人脸识别 springboot框架 前后分离 PC和手机端
    Spring Boot 事物回滚
    allowedOrigins cannot contain the special value "*" gateway 报错
  • 原文地址:https://www.cnblogs.com/MTandHJ/p/15339926.html
Copyright © 2011-2022 走看看