【模板】线性递推+BM算法
给出一个数列 (P) 从 (0) 开始的前 (n) 项,求序列 (P) 在(mod~998244353) 下的最短线性递推式,并在 (mod~ 998244353) 下输出 (P_m)。
(mleq 10^9,1leq nleq 10000),保证递推式最长不超过 (5000)。
Berlekamp-Massey 算法
Berlekamp-Massey 算法,常简称为 BM 算法,是用来求解一个数列的最短线性递推式的算法。
BM 算法可以在 (O(n^2)) 的时间内求解一个长度为 (n) 的数列的最短线性递推式。
基本定义
对于数列 (A={a_1,a_2,dots,a_n}),我们定义数列 (R={r_1,r_2,dots,r_m}) 为其线性递推式当且仅当
注意,无论你习惯从左到右写数列还是从右到左,这里的数列和线性递推式的位置对应关系是反着的。
所有可能的线性递推式 (R) 中长度 (m) 最小的叫做 (A) 的最短线性递推式。
算法流程
假设我们已经求得了 ({a_1,a_2,dots,a_{i-1}}) 的最短线性递推式 ({r_1,r_2,dots,r_m}),那么如何求得 ({a_1,a_2,dots,a_i}) 的最短线性递推式?
定义 ({a_1,a_2,dots,a_{i-1}}) 的最短线性递推式 ({r_1,r_2,dots,r_m}) 为当前递推式,记递推式被更改的次数为 (cnt),第 (i) 次更改后的递推式为 (R_i),那么当前递推式应当为 (R_{cnt})。特别地,定义 (R_0=varnothing)。
我们对每个版本的 (R),记一个表示差异量的数组 (Delta_i),满足
显然若 (Delta_i=0),那么当前递推式就是 ({a_1,a_2,dots,a_i}) 的最短线性递推式。
否则我们认为 (R_{cnt}) 在 (a_i) 处出错了,令 (fail_i) 为 (R_i) 最早的出错位置,则有 (fail_{cnt}=i)。考虑对 (R_{cnt}) 进行修改,使其变为 (R_{cnt+1}),并在 (a_i) 处同样成立。
若当前 (cnt=0),说明 (a_i) 是第一个非零元素,直接将 (R_1) 置为 ({ underbrace{0,0,dots,0}_{i} }) 即可,因为不可能用之前的数递推出 (a_i)。
否,即 (cnt>0),则考虑用 (R_{cnt}) 之前失败的递推式将这个 (Delta_i) 加回去((a=sum+Delta))。我们希望得到数列 (R'={r'_1,r'_2,...,r'_{m'}}),使得
-
[forall kin [m'+1,i-1],~sum_{j=1}^{m'}r'_ja_{k-j}=0 ]
-
[sum_{j=1}^{m'}r'_ja_{i-j}=Delta_i ]
如果能够找到这样的数列 (R'),那么令(R_{cnt+1}=R_{cnt}+R')即可。这里加号表示各位对应相加。
在之前失败的递推式中任选一个 (R_p),尝试在它的基础上修改,在 (i) 的位置上构造出 (Delta_{fail_p})(这里的 (Delta) 是对应 (R_p) 版本的),记得到的结果为 (R_p'),那么
考虑如何构造 (R'_p) 。将 (R_p) 的元素全部变成它的相反数,再在前面补上一个 (1) , (Delta_{fail_p}) 就到 (fail_p+1) 位置上来了。
[a_{fail_p}-sum_{i=1}^{m_p}R_{p,i}a_{fail_p-i}=Delta_{p,fail_p} ]
前面再补上 (i-fail_p-1) 个 (0),(Delta_{fail_p}) 就到 (i) 位置上来了。于是
这里乘号表示顺次连接。
又因为 (R_p) 在 (fail_p) 前的 (Delta=0),所以我们构造出来的 (R') 是满足第一条约束的。
为了保证得到的递推式长度最短,我们需要选取恰当的 (R_p)。容易看出,得到的 (R_{cnt+1}) 的长度为 (max(i-fail_p+m_p,m))。于是记录 (m_p-fail_p) 最短的递推式作为 (R_p) 。
至此我们完成了 BM 算法的理论部分,在最坏情况下,我们可能需要对数列进行 (O(n)) 次修改,因此该算法的时间复杂度为 (O(n^2))。
经验之谈
用 BM 得到的最短递推式长度最好要明显小于 (n) 的一半,否则需要再打些表。
为什么?因为若长度为 (frac n 2),可以看做 (frac n 2) 个变量列出 (frac n 2) 个方程,总能找到解。所以一个随机数列解出的最短递推式长度就是 (frac n 2) 左右。发生了这样的情况说明原数列很可能并没有一定的规律,即递推式大概率对之后的数据不适用。
另外因为计算中涉及除法,所以 BM 在实数域内求解可能有一定的精度误差。
namespace linear{
typedef vector<int> polynomial;
void num_trans(polynomial&a,int dir){
int lim=a.size();
static vector<int> rev,w[2];
if(rev.size()!=lim){
rev.resize(lim);
int len=log2(lim);
for(int i=0;i<lim;++i) rev[i]=rev[i>>1]>>1|(i&1)<<(len-1);
for(int dir=0;dir<2;++dir){
static co int g[2]={3,332748118};
w[dir].resize(lim);
w[dir][0]=1,w[dir][1]=fpow(g[dir],(mod-1)/lim);
for(int i=2;i<lim;++i) w[dir][i]=mul(w[dir][i-1],w[dir][1]);
}
}
for(int i=0;i<lim;++i)if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int step=1;step<lim;step<<=1){
int quot=lim/(step<<1);
for(int i=0;i<lim;i+=step<<1){
int j=i+step;
for(int k=0;k<step;++k){
int t=mul(w[dir][quot*k],a[j+k]);
a[j+k]=add(a[i+k],mod-t),a[i+k]=add(a[i+k],t);
}
}
}
if(dir){
int ilim=fpow(lim,mod-2);
for(int i=0;i<lim;++i) a[i]=mul(a[i],ilim);
}
}
polynomial poly_inv(polynomial a,int n){
polynomial b(1,fpow(a[0],mod-2));
if(n==1) return b;
int lim=2;
for(;lim<n;lim<<=1){
polynomial a1(a.begin(),a.begin()+lim);
a1.resize(lim<<1),num_trans(a1,0);
b.resize(lim<<1),num_trans(b,0);
for(int i=0;i<lim<<1;++i) b[i]=mul(add(2,mod-mul(a1[i],b[i])),b[i]);
num_trans(b,1),b.resize(lim);
}
a.resize(lim<<1),num_trans(a,0);
b.resize(lim<<1),num_trans(b,0);
for(int i=0;i<lim<<1;++i) b[i]=mul(add(2,mod-mul(a[i],b[i])),b[i]);
num_trans(b,1),b.resize(n);
return b;
}
polynomial operator/(polynomial f,polynomial g){
int n=f.size()-1,m=g.size()-1;
reverse(g.begin(),g.end()),g.resize(n-m+1),g=poly_inv(g,n-m+1);
reverse(f.begin(),f.end()),f.resize(n-m+1);
int lim=1<<int(ceil(log2((n-m)<<1|1)));
f.resize(lim),num_trans(f,0);
g.resize(lim),num_trans(g,0);
for(int i=0;i<lim;++i) f[i]=mul(f[i],g[i]);
num_trans(f,1),f.resize(n-m+1);
return reverse(f.begin(),f.end()),f;
}
polynomial operator%(polynomial f,polynomial g){
int n=f.size()-1,m=g.size()-1;
polynomial q=f/g;
int lim=1<<int(ceil(log2(n+1)));
q.resize(lim),num_trans(q,0);
g.resize(lim),num_trans(g,0);
for(int i=0;i<lim;++i) q[i]=mul(q[i],g[i]);
num_trans(q,1);
for(int i=0;i<m;++i) f[i]=add(f[i],mod-q[i]);
return f.resize(m),f;
}
int n,k;
void mul_mod(polynomial&a,polynomial b,co polynomial&p){
static co int lim=1<<int(ceil(log2(2*k-1)));
a.resize(lim),b.resize(lim);
num_trans(a,0),num_trans(b,0);
for(int i=0;i<lim;++i) a[i]=mul(a[i],b[i]);
num_trans(a,1),a.resize(2*k-1);
a=a%p;
}
void main(int _n,int _k,co vector<int>&_a,co vector<int>&_f){
n=_n,k=_k;
polynomial a(k),f(k);
for(int i=1;i<=k;++i) a[k-i]=mod-_a[i];
a.push_back(1);
for(int i=0;i<k;++i) f[i]=_f[i];
polynomial rmd(1,1),tmp(2);tmp[1]=1;
for(;n;n>>=1,mul_mod(tmp,tmp,a))
if(n&1) mul_mod(rmd,tmp,a);
int ans=0;
for(int i=0;i<k;++i) ans=add(ans,mul(rmd[i],f[i]));
printf("%d
",ans);
}
}
vector<int> ber_ma(vector<int> f){
vector<int> lst,cur;
int lsfa,lsdel;
for(int i=0;i<(int)f.size();++i){
int del=f[i];
for(int j=1;j<(int)cur.size();++j)
del=add(del,mod-mul(cur[j],f[i-j]));
if(!del) continue;
if(!cur.size()){
cur.resize(i+1),lsfa=i,lsdel=del;
continue;
}
int alph=mul(del,fpow(lsdel,mod-2));
vector<int> nw(i-lsfa);
nw.push_back(alph);
for(int j=1;j<(int)lst.size();++j)
nw.push_back(mul(alph,mod-lst[j]));
if(nw.size()<cur.size()) nw.resize(cur.size());
for(int j=1;j<(int)cur.size();++j)
nw[j]=add(nw[j],cur[j]);
if(i-lsfa+(int)lst.size()>=(int)cur.size())
lst=cur,lsfa=i,lsdel=del;
cur=nw;
}
return cur;
}
int main(){
int n=read<int>(),m=read<int>();
vector<int> f(n);
for(int i=0;i<n;++i) read(f[i]);
vector<int> a=ber_ma(f);
for(int i=1;i<(int)a.size();++i) printf("%d ",a[i]);
puts("");
if(m<=n) {printf("%d
",f[m]);return 0;}
linear::main(m,a.size()-1,a,f);
return 0;
}
线性递推式是 base 1 的,用 vector 存的话代码有点奇怪。