可以先了解下用分治法求多项式的值 链接。
求 $A(x)$ 在 $w_j,j=0,1,2...,n-1$ 处的值:
设多项式
$$A(x)=a_0+a_1*x+a_2*{x^2}+a_3*{x^3}+a_4*{x^4}+a_5*{x^5}+ dots+a_{n-2}*x^{n-2}+a_{n-1}*x^{n-1}$$
按下标的奇偶性分类,设
$$A_1(x)=a_0+a_2*{x}+a_4*{x^2}+dots+a_{n-2}*x^{frac{n}{2}-1}$$
$$A_2(x)=a_1+a_3*{x}+a_5*{x^2}+ dots+a_{n-1}*x^{frac{n}{2}-1}$$
那么不难得到 $A(x)=A_1(x^2)+xA_2(x^2)$.
我们将 $omega_n^k (k<frac{n}{2})$ 代入得
$$egin{aligned}
A(omega_n^k) &=A_1(omega_n^{2k})+omega_n^kA_2(omega_n^{2k}) \
&=A_1(omega_{frac{n}{2}}^{k})+omega_n^kA_2(omega_{frac{n}{2}}^{k})
end{aligned}$$
同理,将 $omega_n^{k+frac{n}{2}}$ 代入得
$$egin{aligned}
A(omega_n^{k+frac{n}{2}}) &=A_1(omega_n^{2k+n})+omega_n^{k+frac{n}{2}}(omega_n^{2k+n}) \
&=A_1(omega_n^{2k}*omega_n^n)-omega_n^kA_2(omega_n^{2k}*omega_n^n) \
&=A_1(omega_n^{2k})-omega_n^kA_2(omega_n^{2k})
end{aligned}$$
两个式子只有符号不同,也就是说,算出在前 $n$ 个点的值就能得到后 $n$ 个点的值,相当于问题规模减半了。
所以可以递归的实现,直到多项式仅剩一个常数项,这时候我们直接返回就好啦!
这样时间复杂度为 $O(nlogn)$.
FFT算法的伪代码:
1.求值 $A(w_j), B(w_j)$,$j=0,1,..2n-1$
2. 计算 $C(w_j)$,$j=0,1,...,2n-1$
3. 构造多项式
$D(x)=C(w_0) + C(w_1)x+...+C(w_{2n-1})x^{2n-1}$
4. 计算所有的 $D(w_j)$,$j=0,1,...2n-1$
5. 利用下式计算 $C(x)$ 的系数 $c_j$
$D(w_j) = 2nc_{2n-j}, j=1,...,2n-1$
$D(w_0) = 2nc_0$
递归版
fft函数 $type=1$ 进去的时候 $a$ 数组存的是系数,返回时存的是计算出来的值,$a[i]=A(w_i)$;
$type=-1$ 时相反。
这里预处理了sin和cos值,大概能快2、3倍。
// luogu-judger-enable-o2 #include<iostream> #include<cstdio> #include<cmath> using namespace std; const int MAXN = 4 * 1e6 + 10; //4倍空间 inline int read() { char c = getchar(); int x = 0, f = 1; while (c < '0' || c > '9') {if (c == '-')f = -1; c = getchar();} while (c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();} return x * f; } const double Pi = acos(-1.0); double ccos[MAXN], ssin[MAXN]; struct complex { double x, y; complex (double xx = 0, double yy = 0) {x = xx, y = yy;} } a[MAXN], b[MAXN]; complex operator + (complex a, complex b) { return complex(a.x + b.x , a.y + b.y);} complex operator - (complex a, complex b) { return complex(a.x - b.x , a.y - b.y);} complex operator * (complex a, complex b) { return complex(a.x * b.x - a.y * b.y , a.x * b.y + a.y * b.x);} //不懂的看复数的运算那部分 void fast_fast_tle(int limit, complex *a, int type) { if (limit == 1) return ; //只有一个常数项 complex a1[limit >> 1], a2[limit >> 1]; for (int i = 0; i < limit; i += 2) //根据下标的奇偶性分类 a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1]; fast_fast_tle(limit >> 1, a1, type); fast_fast_tle(limit >> 1, a2, type); complex Wn = complex(ccos[limit] , type * ssin[limit]), w = complex(1, 0); //complex Wn = complex(cos(2.0 * Pi / limit) , type * sin(2.0 * Pi / limit)), w = complex(1, 0); //Wn为单位根,w表示幂 for (int i = 0; i < (limit >> 1); i++, w = w * Wn) //这里的w相当于公式中的k { complex tmp = w * a2[i]; a[i] = a1[i] + tmp; a[i + (limit >> 1)] = a1[i] - tmp; //利用单位根的性质,O(1)得到另一部分 } } int main() { int N = read(), M = read(); //N,M是秩最高次幂 for (int i = 0; i <= N; i++) a[i].x = read(); for (int i = 0; i <= M; i++) b[i].x = read(); int limit = 1; while (limit <= N + M) limit <<= 1; for(int i = 1;i <= limit;i++) { ccos[i] = cos(2.0 * Pi / i); ssin[i] = sin(2.0 * Pi / i); } fast_fast_tle(limit, a, 1); fast_fast_tle(limit, b, 1); //后面的1表示要进行的变换是什么类型 //1表示从系数变为点值 //-1表示从点值变为系数 //至于为什么这样是对的,可以参考一下c向量的推导过程, for (int i = 0; i <= limit; i++) a[i] = a[i] * b[i]; fast_fast_tle(limit, a, -1); for (int i = 0; i <= N + M; i++) printf("%d ", (int)(a[i].x / limit + 0.5)); //按照我们推倒的公式,这里还要除以n return 0; }
非递归版
这个很容易发现点什么吧?
- 每个位置分治后的最终位置为其二进制翻转后得到的位置
这样的话我们可以先把原序列变换好,把每个数放在最终的位置上,再一步一步向上合并。
一句话就可以 $O(n)$ 预处理出位置 $i$ 最终的位置 $rev[i]$:
//原理也很简单,将高bit-1位(也就是i/2)反转,再将第一位补到最高位。
fo(i,0,n-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
// luogu-judger-enable-o2 #include<iostream> #include<cstdio> #include<cmath> using namespace std; const int MAXN = 4e6 + 10; //开4倍空间 inline int read() { char c = getchar(); int x = 0, f = 1; while (c < '0' || c > '9') {if (c == '-')f = -1; c = getchar();} while (c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();} return x * f; } const double Pi = acos(-1.0); struct complex { double x, y; complex (double xx = 0, double yy = 0) {x = xx, y = yy;} } a[MAXN], b[MAXN]; complex operator + (complex a, complex b) { return complex(a.x + b.x , a.y + b.y);} complex operator - (complex a, complex b) { return complex(a.x - b.x , a.y - b.y);} complex operator * (complex a, complex b) { return complex(a.x * b.x - a.y * b.y , a.x * b.y + a.y * b.x);} //不懂的看复数的运算那部分 int N, M; int bit, r[MAXN]; int limit = 1; void fast_fast_tle(complex *A, int type) { for (int i = 0; i < limit; i++) if (i < r[i]) swap(A[i], A[r[i]]); //求出要迭代的序列 for (int mid = 1; mid < limit; mid <<= 1) { //待合并区间的长度的一半 complex Wn( cos(Pi / mid) , type * sin(Pi / mid) ); //单位根 for (int R = mid << 1, j = 0; j < limit; j += R) { //R是区间的长度,j表示前已经到哪个位置了 complex w(1, 0); //幂 for (int k = 0; k < mid; k++, w = w * Wn) { //枚举左半部分 complex x = A[j + k], y = w * A[j + mid + k]; //蝴蝶效应 A[j + k] = x + y; A[j + mid + k] = x - y; } } } } int main() { int N = read(), M = read(); for (int i = 0; i <= N; i++) a[i].x = read(); for (int i = 0; i <= M; i++) b[i].x = read(); while (limit <= N + M) limit <<= 1, bit++; for (int i = 0; i < limit; i++) r[i] = ( r[i >> 1] >> 1 ) | ( (i & 1) << (bit - 1) ) ; fast_fast_tle(a, 1); fast_fast_tle(b, 1); for (int i = 0; i <= limit; i++) a[i] = a[i] * b[i]; fast_fast_tle(a, -1); for (int i = 0; i <= N + M; i++) printf("%d ", (int)(a[i].x / limit + 0.5)); return 0; }
参考链接:
(建议直接看大佬的,我都是copy过来的,整理一下思路而已)
1 https://www.cnblogs.com/zwfymqz/p/8244902.html
2. https://blog.csdn.net/enjoy_pascal/article/details/81478582