这是优美的多项式家族
快速傅里叶变换(FFT)
问题:多项式乘法
原理先不写了,思想就是把系数表达转化为点值表达,点值运算之后再变回系数表达,复杂度(O(nlogn))
点值选取的是负数域中的n次单位根
有时间会补上这块内容的
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
const int N = 4e6;
const double Pi = acos(-1.0);
using namespace std;
struct node
{
double x,y;
}a[N + 5],b[N + 5],w[N + 5];
int n,m,maxn,rev[N + 5],lg;
node operator +(node a,node b)
{
return (node){a.x + b.x,a.y + b.y};
}
node operator -(node a,node b)
{
return (node){a.x - b.x,a.y - b.y};
}
node operator *(node a,node b)
{
return (node){a.x * b.x - a.y * b.y,a.x * b.y + a.y * b.x};
}
void fft(node *a,int typ)
{
for (int i = 0;i < maxn;i++)
if (i < rev[i])
swap(a[i],a[rev[i]]);
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < maxn;j += i << 1)
for (int k = 0;k < i;k++)
{
node x = a[k + j],t = (node){w[i + k].x,w[i + k].y * typ} * a[k + j + i];
a[k + j] = x + t;
a[k + j + i] = x - t;
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i = 0;i <= n;i++)
scanf("%lf",&a[i].x);
for (int i = 0;i <= m;i++)
scanf("%lf",&b[i].x);
maxn = 1;
while (maxn <= m + n)
maxn <<= 1,lg++;
for (int i = 0;i <= maxn;i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < i;j++)
w[i + j] = (node){cos(Pi * j / i),sin(Pi * j / i)};
fft(a,1);
fft(b,1);
for (int i = 0;i < maxn;i++)
a[i] = a[i] * b[i];
fft(a,-1);
for (int i = 0;i <= n + m;i++)
printf("%d ",(int)(a[i].x / maxn + 0.1));
return 0;
}
快速数论变换(NTT)
就是把问题转化为了在模意义下,于是我们可以选择和单位根有类似性质的原根,时间复杂度仍是(O(nlogn))
#include <iostream>
#include <cstdio>
#include <algorithm>
const int N = 5e6;
const int P = 998244353;
using namespace std;
int n,m,rev[N + 5],maxn,lg,a[N + 5],b[N + 5],g[N + 5][3];
int mypow(int a,int x)
{
int s = 1;
while (x)
{
if (x & 1)
s = 1ll * s * a % P;
a = 1ll * a * a % P;
x >>= 1;
}
return s;
}
void ntt(int *a,int typ)
{
for (int i = 0;i < maxn;i++)
if (i < rev[i])
swap(a[i],a[rev[i]]);
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < maxn;j += i << 1)
for (int k = 0;k < i;k++)
{
int x = a[k + j],t = 1ll * g[k + i][typ] * a[k + i + j] % P;
a[k + j] = (x + t) % P;
a[k + i + j] = ((x - t) % P + P) % P;
}
}
int main()
{
scanf("%d%d",&n,&m);
for (int i = 0;i <= n;i++)
scanf("%d",&a[i]);
for (int i = 0;i <= m;i++)
scanf("%d",&b[i]);
maxn = 1;
while (maxn <= n + m)
maxn <<= 1,lg++;
for (int i = 0;i <= maxn;i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
for (int i = 1;i < maxn;i <<= 1)
{
int G1 = mypow(3,(P - 1) / (i << 1)),G2 = mypow(mypow(3,P - 2),(P - 1) / (i << 1));
g[i][1] = 1;
g[i][0] = 1;
for (int j = 1;j < i;j++)
g[i + j][1] = 1ll * g[i + j - 1][1] * G1 % P,g[i + j][0] = 1ll * g[i + j - 1][0] * G2 % P;
}
ntt(a,1);
ntt(b,1);
for (int i = 0;i < maxn;i++)
a[i] = 1ll * a[i] * b[i] % P;
ntt(a,0);
int inv = mypow(maxn,P - 2);
for (int i = 0;i <= n + m;i++)
printf("%d ",1ll * a[i] * inv % P);
return 0;
}
多项式求逆
问题:给定一个多项式(F(x)),求一个多项式(G(x)),满足(F(x)G(x)equiv 1(mod x^n))
假设我们已经求出了一个(F(x))在(mod x^n)下的逆(G'(x)),我们要求在(mod x^{2n})下的逆(G(x))
那么考虑
于是就可以愉快地递归求解了,时间复杂度(T(n)=T(n/2)+O(nlogn)=O(nlogn))
Code
int INVa[N + 5];
void INV(int *a,int *ans,int n)
{
if (n == 1)
{
ans[0] = mypow(a[0],p - 2);
return;
}
INV(a,ans,n + 1 >> 1);
pre(n * 2);
for (int i = 0;i < n;i++)
INVa[i] = a[i];
clear(INVa,maxn,n);
ntt(INVa,1);
ntt(ans,1);
for (int i = 0;i < maxn;i++)
ans[i] = (2ll * ans[i] % p - 1ll * INVa[i] * ans[i] % p * ans[i] % p) % p;
ntt(ans,0);
clear(ans,maxn,n);
}
多项式对数函数(多项式 ln)
问题:给出 (n-1) 次多项式 (A(x)),求一个 (mod{:x^n}) 下的多项式 (B(x)),满足 (B(x) equiv ln A(x)).
对两边同时求导(B'(x)equiv frac{A'(x)}{A(x)})
积分回去(B(x)equiv int frac{A'(x)}{A(x)}dx)
然后就是求导公式和积分公式
Code
int Lna[N + 5],Lnb[N + 5];
void DOV(int *a,int *f,int n)
{
for (int i = 1;i < n;i++)
f[i - 1] = 1ll * i * a[i] % p;
f[n - 1] = 0;
}
void DOVINV(int *a,int *f,int n)
{
f[0] = 0;
for (int i = 1;i < n;i++)
f[i] = 1ll * mypow(i,p - 2) * a[i - 1] % p;
}
void Ln(int *a,int *ans,int n)
{
DOV(a,Lna,n);
pre(n * 2);
clear(Lnb,maxn);
INV(a,Lnb,n);
pre(n * 2);
clear(Lna,maxn,n);
ntt(Lna,1);
ntt(Lnb,1);
for (int i = 0;i < maxn;i++)
Lna[i] = 1ll * Lna[i] * Lnb[i] % p;
ntt(Lna,0);
DOVINV(Lna,ans,n);
clear(ans,maxn,n);
}
多项式指数函数(多项式 exp)
问题:给出 (n-1) 次多项式 (A(x)),保证(A_0=0),求一个 (mod{:x^n}) 下的多项式 (B(x)),满足 (B(x) equiv ext e^{A(x)})。
考虑用牛顿迭代解决这个问题
设(F(B(x))=lnB(x)-A(x))
把(A(x))看作常数项,所以(F'(B(x))=frac{1}{B(x)})
代入牛顿迭代的式子有
倍增求解即可
Code
int expa[N + 5],expb[N + 5];
void exp(int *a,int *ans,int n)
{
if (n == 1)
{
ans[0] = 1;
return;
}
exp(a,ans,n + 1 >> 1);
Ln(ans,expa,n);
pre(n * 2);
for (int i = 0;i < n;i++)
expb[i] = a[i];
clear(expb,maxn,n);
ntt(ans,1);
ntt(expa,1);
ntt(expb,1);
for (int i = 0;i < maxn;i++)
ans[i] = 1ll * ans[i] * ((1 - expa[i] + expb[i]) % p) % p;
ntt(ans,0);
clear(ans,maxn,n);
}
多项式快速幂
问题:给定一个 (n-1) 次多项式 (A(x)),求一个在 (mod x^n) 意义下的多项式 (B(x)),使得 (B(x) equiv A^k(x) (mod x^n))
我们对两边先ln再exp可以得到
于是(k)也可以取模了
然后注意到数据不一定保证(A_0=1),那么我们可以找到第一个非(0)的项(a),把(A(x))的每一项都除以(a),变成(frac{A(x)}{a}),并将后面的移到前面,这样就可以保证(A_0=1),最后再乘(a^k)并且处理(0)即可
Code
int pa[N + 5];
void mypow(int *a,int *ans,int n,int k)
{
Ln(a,pa,n);
for (int i = 0;i < n;i++)
pa[i] = 1ll * pa[i] * k % p;
exp(pa,ans,n);
}
多项式开根
问题:给定一个(n-1)次多项式(A(x)),求一个在(mod x^n)意义下的多项式(B(x)),使得(B^2(x) equiv A(x) (mod x^n))。若有多解,请取零次项系数较小的作为答案。
设(H^2(x)equiv F(x)(mod x^n))
那么考虑
倍增即可,只有一项的时候需要用二次剩余求根号
不过其实也可以先ln再exp回去
Code
int sqra[N + 5],sqrtmp[N + 5];
void sqr(int *a,int *ans,int n)
{
if (n == 1)
{
ans[0] = sq;
return;
}
sqr(a,ans,n + 1 >> 1);
pre(n * 2);
clear(sqra,maxn);
clear(sqrtmp,maxn);
INV(ans,sqra,n);
pre(n * 2);
for (int i = 0;i < n;i++)
sqrtmp[i] = a[i];
ntt(sqra,1);
ntt(sqrtmp,1);
ntt(ans,1);
int t = mypow(2,p - 2);
for (int i = 0;i < maxn;i++)
ans[i] = 1ll * ((sqrtmp[i] + 1ll * ans[i] * ans[i] % p) % p) * t % p * sqra[i] % p;
ntt(ans,0);
int inv = mypow(maxn,p - 2);
for (int i = 0;i < n;i++)
ans[i] = 1ll * ans[i] * inv % p;
clear(ans,maxn,n);
}
多项式除法
问题:给定一个(n)次多项式(F(x))和一个(m)次多项式(G(x)),求出多项式(Q(x),R(x))满足:
- (Q(x))次数为(n-m),(R(x))次数小于(m)
- (F(x)=Q(x)G(x)+R(x))
首先设一个(n)项多项式(A(x)),假设一个(r)操作使得(A_r(x)=x^nA(frac{1}{x}))
那么可以看出(A_r[i]=A[n-i])
然后考虑下面的式子
于是我们对(G_r(x))求逆,然后求得(Q_r(x)),再带回得到(Q(x))
最后根据(R(x)=F(x)-Q(x)G(x))求得(R(x))
时间复杂度(O(nlogn))
Code
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
const int P = 998244353;
const int N = 1e6;
using namespace std;
int mypow(int a,int x)
{
int s = 1;
while (x)
{
if (x & 1)
s = 1ll * s * a % P;
a = 1ll * a * a % P;
x >>= 1;
}
return s;
}
int n,m,F[N + 5],G[N + 5],Q[N + 5],GR[N + 5],w[N + 5][3],maxn,lg,rev[N + 5],Gi[N + 5],c[N + 5],FR[N + 5];
void R(int *a,int *b,int n)
{
for (int i = 0;i <= n;i++)
b[i] = a[n - i];
}
void ntt(int *a,int typ)
{
for (int i = 0;i < maxn;i++)
if (i < rev[i])
swap(a[i],a[rev[i]]);
for (int i = 1;i < maxn;i <<= 1)
for (int j = 0;j < maxn;j += i << 1)
for (int k = 0;k < i;k++)
{
int x = a[j + k],t = 1ll * w[i + k][typ] * a[i + j + k] % P;
a[j + k] = (x + t) % P;
a[j + k + i] = ((x - t) % P + P) % P;
}
}
void ntt_pre(int n)
{
maxn = 1;
lg = 0;
while (maxn <= n)
maxn <<= 1,lg++;
for (int i = 0;i < maxn;i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
}
void INV(int n,int *a,int *b)
{
if (n == 1)
{
b[0] = mypow(a[0],P - 2);
return;
}
INV((n + 1) >> 1,a,b);
ntt_pre(n << 1);
for (int i = 0;i < n;i++)
c[i] = a[i];
for (int i = n;i < maxn;i++)
c[i] = 0;
ntt(c,1);
ntt(b,1);
for (int i = 0;i < maxn;i++)
b[i] = ((2ll * b[i] % P - 1ll * c[i] * b[i] % P * b[i] % P) % P + P) % P;
ntt(b,2);
int inv = mypow(maxn,P - 2);
for (int i = 0;i < n;i++)
b[i] = 1ll * b[i] * inv % P;
for (int i = n;i < maxn;i++)
b[i] = 0;
}
void NR(int *a,int *b,int n)
{
for (int i = 0;i <= n;i++)
b[n - i] = a[i];
}
int main()
{
scanf("%d%d",&n,&m);
for (int i = 0;i <= n;i++)
scanf("%d",&F[i]);
for (int i = 0;i <= m;i++)
scanf("%d",&G[i]);
maxn = 1;
while (maxn <= (n + m) * 2)
maxn <<= 1;
for (int i = 1;i < maxn;i <<= 1)
{
int G1 = mypow(3,(P - 1) / (i << 1)),G2 = mypow(mypow(3,P - 2),(P - 1) / (i << 1));
w[i][1] = w[i][2] = 1;
for (int j = 1;j < i;j++)
w[i + j][1] = 1ll * w[i + j - 1][1] * G1 % P,w[i + j][2] = 1ll * w[i + j - 1][2] * G2 % P;
}
R(G,GR,m);
INV(n - m + 2,GR,Gi);
R(F,FR,n);
ntt_pre(n * 2 - m + 2);
ntt(FR,1);
ntt(Gi,1);
for (int i = 0;i < maxn;i++)
Gi[i] = 1ll * Gi[i] * FR[i] % P;
ntt(Gi,2);
int inv = mypow(maxn,P - 2);
for (int i = 0;i < maxn;i++)
Gi[i] = 1ll * Gi[i] * inv % P;
NR(Gi,Q,n - m);
for (int i = 0;i <= n - m;i++)
printf("%d ",Q[i]);
cout<<endl;
for (int i = n - m + 1;i < maxn;i++)
Q[i] = 0;
ntt_pre(n + m);
ntt(Q,1);
ntt(G,1);
ntt(F,1);
for (int i = 0;i < maxn;i++)
F[i] = ((F[i] - 1ll * Q[i] * G[i] % P) % P + P) % P;
ntt(F,2);
inv = mypow(maxn,P - 2);
for (int i = 0;i < m;i++)
printf("%d ",1ll * F[i] * inv % P);
return 0;
}