[笔记] [题解]多项式学习
(mathcal{FFT})
参考博客:( extit{litble})学长的校内博客
前置知识
在接下来的讲解中可能会用到一些高中数学知识,现在先稍微讲解一下(主要是我不会啊)
虚数&复数
定义
虚数的基本单位:(i=sqrt{-1})
复数:一个复数(x)可以表示为(x = a+bi)
运算
加减:(large (a + bi)+(c + di)=(a+c)+(b+d)i)
乘除:(large (a+bi)(c+di)=ac+bci+adi-bd=(ac-bd)+(bc+ad)i)
其他形式
复数的三角形式如图
偷来的图
其中( heta)是复数的辐角,可以表示成( heta+2kpi)的形式。
那么我们可以得出复数乘除运算的三角式:
(large r_1(cos heta_1+i sin heta_1)r_2(cos heta_2+i sin heta_2)=r_1r_2(cos( heta_1+ heta_2)+sin( heta_1+ heta_2)))
也就是说复数积的模等于各个复数的模的积,乘积的辐角等于各个复数辐角的和。
单位根
(n)次单位根就是满足(omega^n=1)的复数
由复数的三角式乘除运算法则可以知道有(n)个这样的复数,它们分布于平面的单位圆上,并且平分这个单位圆。
(n)次单位根是:(large e^{frac{2pi ki}{n}},k=0,1,2,dots,n-1)
还有一个接下来的变形中会用到的公式:欧拉公式:(e^{ heta i}= ext{cos} heta+i ext{sin} heta)
因此就可以得到(n)次单位根的算术表示法:记(omega_n=e^{frac{2pi i}{n}})
总结一下单位根的性质:
多项式乘法
给定多项式(large A(x)=sum^n_{i=0}a_ix^i)和(large B(x)=sum^n_{i=0}b_ix^i),则它们的积是(large C(x)=sum_{j+k=i,0leq j,kleq n} a_jb_kx^i)
(mathcal{FFT})具体知识
折半引理
对于(n>0)且(n)为奇数有:
证明:(large (omega_n^{k+frac{n}{2}})^2=omega_n^{2k+n}=omega^{2k}_nomega^n_n=(omega^k_n)^2=omega^k_{n/2})。可以参照上面的欧拉公式。
快速傅里叶变换
就是多项式快速转化为点值表示法。
首先进行奇偶性划分:
所以就把(A(x))变成了(A_0(x^2)+xA_1(x^2))
同时用复数(omega^k_n)来加速。(在下文蝶形变换有差不多的解释,看不懂的没关系)
可以得到(large A(omega^k_n)=A_0((omega^k_n)^2)+omega^k_n A_1((omega^k_n)^2)------------(1))
由于(large (e^{frac{2pi i}{n}})^{2k}=(e^{frac{4pi i}{n}})^k)
所以(large A(omega^k_n)=A_0(omega^k_{n/2})+omega^k_{n/2})
根据折半引理,(large omega_n^{k+frac{n}{2}}=(e^{frac{2pi i}{n}})^{k+frac{n}{2}}=(e^{frac{2pi i}{n}})^ke^{pi i}=omega^k_nomega_n^{frac{n}{2}}=-omega^k_n)
可以得到(large A(omega_n^{k+n/2})=A_0(omega^k_{n/2})-omega^k_nA_1(omega^k_{n/2})-----------(2))
当(large kin[0,frac{n}{2}-1])时,(large k+frac{n}{2}in[frac{n}{2},n-1])
这样利用分支来实现的复杂度是(O(nlog_2n))
蝶形变换
这种算法的英文名称是(Cooley-Tukey)算法。
假设现在有一个(n-1)次多项式(large A(x)=sum^{n-1}_{i=0}a_ix^i)(方便起见,设(n=2^m,min))
将(n)个(n)次单位根(omega^0_n,omega^1_n,dots,omega^{n-1}_n)带入多项式(A(x))将其转换成点值表达
接下来把每一项进行奇偶分类
前面有提到(large omega^2_n=(e^{frac{2pi i}{n}})=e^{frac{2pi i}{n/2}}=omega_{frac{n}{2}}),也就是说要带入的值经过平方之后变少了一半,原因是单位根把单位元平分,那么肯定具有对称性,所以说肯定有一正一负两个,平方之后自然就相等了。
也就是说当(k<frac{n}{2})时
这样我们带入的值也就变成了(large 1,omega_{frac{n}{2}}^1,omega_{frac{n}{2}}^2,dots,omega_{frac{n}{2}}^{frac{n}{2}-1},)也就是把单位圆上的单位根一次代入,这样的复杂度就是(large O(nlog_2n))
举一个具体一点的例子来描述一下奇偶分类的具体过程:
初始的系数:(large omega_n^0omega_n^1omega_n^2omega_n^3omega_n^4omega_n^5omega_n^6omega_n^7)
一次变换后:(large omega^0_nomega^2_nomega^4_nomega^6_nomega^1_nomega^3_nomega^5_nomega^7_n)
两次变换后:(large omega^0_nomega^4_nomega^2_nomega^6_nomega^1_nomega^5_nomega^3_nomega^7_n)
傅里叶逆变换
我目前不是很懂,不过过程是:把原来傅里叶变换中(omega_n^i)换成(omega_n^{-i}),然后做一次傅里叶变换,之后把得到的结果除以(n)即可。
代码实现
这个是这个题
#include <bits/stdc++.h>
using namespace std;
const int N = 3000010;
const double pi = 3.1415926535897384626;
struct complex_num{
double r,i;
}a[N],b[N];
int n,m,len,rev[N];
complex_num operator + (complex_num a,complex_num b){
return (complex_num){a.r + b.r,a.i + b.i};
}
complex_num operator - (complex_num a,complex_num b){
return (complex_num){a.r - b.r,a.i - b.i};
}
complex_num operator * (complex_num a,complex_num b){
return (complex_num){a.r * b.r - a.i * b.i,a.i * b.r + a.r * b.i};
}
complex_num operator / (complex_num a,double c){
return (complex_num){a.r / c,a.i / c};
}
void FFT(complex_num *a,int x){
for(int i = 0;i < n;i++)
if(i < rev[i])
swap(a[i],a[rev[i]]);//防止一个元素交换两次回到它原来的位置
for(int i = 1;i < n;i <<= 1){
complex_num wn = (complex_num){cos(pi / i),x * sin(pi / i)};
for(int j = 0;j < n;j += (i << 1)){
complex_num w = (complex_num){1,0},tmp1,tmp2;
for(int k = 0;k < i;k++,w = w * wn){
tmp1 = a[j + k],tmp2 = w * a[j + k + i];
a[j + k] = tmp1 + tmp2;a[j + k + i] = tmp1 - tmp2;
}
}
}
if(x == -1)for(int i = 0;i < n;i++)a[i] = a[i] / n;
}
int main(){
scanf("%d%d",&n,&m);
for(int i = 0;i <= n;i++)scanf("%lf",&a[i].r);
for(int i = 0;i <= m;i++)scanf("%lf",&b[i].r);
m = n + m;
for(n = 1;n <= m;n <<= 1)len++;
for(int i = 0;i < n;i++)rev[i] = (rev[i >> 1] >> 1) | (i & 1) << (len - 1);
FFT(a,1);FFT(b,1);
for(int i = 0;i <= n;i++)a[i] = a[i] * b[i];
FFT(a,-1);
for(int i = 0;i <= m;i++)printf("%d ",(int)(a[i].r + 0.5));
return 0;
}
还有这个题
#include <bits/stdc++.h>
using namespace std;
const long long N = 3000010;
const double pi = 3.1415926535897384626;
struct com{
double r,i;
}a[N],b[N];
long long ans[N];
long long n,m,len,rev[N];
char s[N];
com operator + (com a,com b){
return (com){a.r + b.r,a.i + b.i};
}
com operator - (com a,com b){
return (com){a.r - b.r,a.i - b.i};
}
com operator * (com a,com b){
return (com){a.r * b.r - a.i * b.i,a.r * b.i + b.r * a.i};
}
com operator / (com a,double c){
return (com){a.r / c,a.i / c};
}
void FFT(com *a,long long x){
for(long long i = 0;i < n;i++)if(i < rev[i])swap(a[i],a[rev[i]]);//防止交换两次,等同于没有交换
for(long long i = 1;i < n;i <<= 1){//i是准备合并的序列的长度的一半
com wn = (com){cos(pi / i),x * sin(pi / i)};//单位根
for(long long j = 0;j < n;j += (i << 1)){//j是合并到了哪一位
com w = (com){1,0},tmp1,tmp2;
for(long long k = 0;k < i;k++,w = w * wn){//只扫左半部分,同时得到右半部分的答案(蝴蝶变换)
tmp1 = a[j + k],tmp2 = w * a[j + k + i];
a[j + k] = tmp1 + tmp2;//对应上面快速傅里叶变换的(1)
a[j + k + i] = tmp1 - tmp2;//对应上面快速傅里叶变换的(2)
}
}
}
if(x == -1)for(long long i = 0;i < n;i++)a[i] = a[i] / n;
}
signed main(){
scanf("%s",&s);
n = strlen(s);
for(long long i = n - 1;i >= 0;i--)a[n - i - 1].r = s[i] - '0';//下标从0开始
scanf("%s",&s);
m = strlen(s);
for(long long i = m - 1;i >= 0;i--)b[m - i - 1].r = s[i] - '0';
m = n + m;
for(n = 1;n <= m;n <<= 1)len++;
for(long long i = 0;i < n;i++)rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (len - 1)));
FFT(a,1);FFT(b,1);
for(long long i = 0;i <= n;i++)a[i] = a[i] * b[i];
FFT(a,-1);
for(long long i = 0;i <= m;i++)ans[i] = (long long)(a[i].r + 0.5);
len = 0;
for(long long i = 0;i <= m;i++){//算出来的是系数,可能大于10,要进位
ans[i + 1] += ans[i] / 10;
ans[i] %= 10;
len++;
}
while(ans[len] >= 10){
ans[++len] += ans[len - 1] / 10;
ans[len - 1] %= 10;
}
while(ans[len] == 0)len--;
for(long long i = len;i >= 0;i--)printf("%d",ans[i]);
printf("
");
return 0;
}
一些对于程序的解释
在第二个程序中,经过第(46)行的操作,可以得到(2^kge2n)且(2^{k-1}<2n),(n=2^k)。这样是为了使无论将单位圆分成几份都会是整数份数。
第(47)行的程序的作用是:因为我们需要进行奇偶分类,这里就有一个性质,比如说现在要算出下标为(4)的元素在奇偶分类之后排在哪一位,那么我们先表示出(4)的二进制数(0100),再把这个二进制数颠倒得到(0010)对应的十进制数就是下标为(4)的元素的位置,位于第(2)个。可以参见上面奇偶分类系数的变换。
在(mathcal{FFT})函数中,三重循环的三个循环变量(i,j,k)分别代表:把单位圆分成几份,从第几个单位根开始在单位圆上转,当前计算到了哪一个单位根。
(mathcal{NTT})
在实现(mathcal{FFT})的时候我们会发现其实在计算过程中是有精度损失的,因为我们利用(omega_n)实现了折半引理。有没有什么整数可以用来代替(omega_n)呢?
原根可以取而代之。定义(P)的原根为满足(large g^{phi(P)}equiv1(mod P))的整数(g)。
我们用(large g^{frac{phi(P)}{n}})代替(n)次单位根进行计算,因为(P)是质数,所以(phi(P)=P-1),有要求(large frac{phi(P)}{n})为整数,(n)还是(2)的整数次幂,所以要求(large P=k*2^q+1),其中(2^qge n)。
怎么求原根呢?如果题目没有给出模数,就要用(mathcal{BSGS}),如果(P)不是质数就要用中国剩余定理合并。
另一种定义是:若有(g)使得(g^imod P)的结果两两不同,(P)为质数,且(gin[2,p-1]i,iin[1,p-1]),那么称(g)是(P)的原根。比如说(998244353)的原根就是(3)。
代码的坑还没填。。。
拉格朗日插值
先放一道例题
题目大意
给出(n)个点(P_i(x_i,y_i)),将过这(n)个点的最多(n-1)次的多项式记为(f(x)),求(f(k))的值。
拉格朗日插值
设我们现在有给定的(n+ 1)个点,分别是((x_0,y_0),(x_1,y_1),dots,(x_n,y_n))
则拉格朗日基本多项式为
我们可以发现(large ell_j(x_j)=1),并且(large ell_j(x_i)=0,forall i e j),也就是说(large ell_j(x_i))函数的作用就是让函数的返回值只有(0)或(1),而且在传入(x_j)的时候返回(1),其余时候返回(0)。
接着就是(n)次多项式
观察上式,我们可以发现(large P(x_i)=y_i),也就是经过了给定的(n+1)个点。
整合上面的两个公式得到最终的拉格朗日插值法的公式:
对于例题而言只要求出(f(k))的值即可。
代码
#include <bits/stdc++.h>
using namespace std;
const long long mod = 998244353;
long long x[20010],y[20010],ans,tmp1,tmp2;
inline long long qpow(long long x,long long y){
long long res = 1ll;
while(y){
if(y & 1){
res = res * x;
res %= mod;
}
x = x * x;x %= mod;
y >>= 1;
}
return res;
}
long long n,k;
int main(){
scanf("%lld%lld",&n,&k);
for(int i = 1;i <= n;i++)scanf("%lld%lld",&x[i],&y[i]);
for(int i = 1;i <= n;i++){
tmp1 = y[i] % mod;
tmp2 = 1ll;
for(int j = 1;j <= n;j++){
if(i != j)
tmp1 = tmp1 * (k - x[j]) % mod,tmp2 = tmp2 * (x[i] - x[j]) % mod;
}
ans += tmp1 * qpow(tmp2,mod - 2) % mod;
}
printf("%lld
",(ans % mod + mod) % mod);
return 0;
}
多项式操作
多项式求逆
定义
对于一个多项式(A(x))如果存在(B(x))满足(B)的次数不大于(A)并且
那么称(B(x))为(A(x))在(mod x^n)意义下的逆元,记作(A^{-1}(x))
(mod x^n)是忽略次数(ge n)的项。
求解方法
假设(A(x))在(mod x^{frac{n}{2}})的意义下的逆元为(B_0(x)),那么就有
再把上面两个式子做差,得到:
再进行化简:
左右两边同时平方:
多项式长度翻倍后上式依然成立:
左右两边同时乘以(A(x))并且由于(A(x)B(x)equiv 1(mod x^n)),所以可以化简:
再经过移项就得到了最终的结果:
这个式子可以倍增或者递归来求。
代码
#include<bits/stdc++.h>
using namespace std;
const int mod = 998244353,G = 3,N = 2100000;
int n;
int a[N],b[N],c[N],rev[N];
inline int qpow(int x,int y) {
int res = 1;
while(y){
if(y & 1){
res = 1LL * res * x % mod;
}
x = 1LL * x * x % mod;
y >>= 1;
}
return res;
}
inline void NTT(int *a,int n,int x) {
for(int i = 0;i < n;i++)
if(i < rev[i])
swap(a[i],a[rev[i]]);
for(int i = 1;i < n;i <<= 1) {
int gn = qpow(G,(mod - 1) / (i << 1));
for(int j = 0;j < n;j += (i << 1)) {
int t1,t2,g = 1;
for(int k = 0;k < i;k++,g = 1LL * g * gn % mod) {
t1 = a[j + k],t2 = 1LL * g * a[j + k + i] % mod;
a[j + k] = (t1 + t2) % mod,a[j + k + i] = (t1 - t2 + mod) % mod;
}
}
}
if(x == 1)return;
int inv = qpow(n,mod - 2);
reverse(a + 1,a + n);
for(int i = 0;i < n;i++) a[i] = 1LL * a[i] * inv % mod;
}
void work(int deg,int *a,int *b) {
if(deg == 1){
b[0] = qpow(a[0],mod - 2);
return;
}
work((deg + 1) >> 1,a,b);
int len = 0,rhs = 1;
while(rhs < (deg << 1))rhs <<= 1,len++;
for(int i = 1;i < rhs;i++)rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (len - 1));
for(int i = 0;i < deg;i++)c[i] = a[i];
for(int i = deg;i < rhs;i++)c[i] = 0;
NTT(c,rhs,1),NTT(b,rhs,1);
for(int i = 0;i < rhs;i++)
b[i] = 1LL * (2 - 1LL * c[i] * b[i] % mod + mod) % mod * b[i] % mod;
NTT(b,rhs,-1);
for(int i = deg;i < rhs;i++)b[i] = 0;
}
int main(){
scanf("%d",&n);
for(int i = 0;i < n;i++)scanf("%d",&a[i]);
work(n,a,b);
for(int i = 0;i < n;i++)printf("%d ",b[i]);
return 0;
}
未完待续...