网上相关博客不少,这里给自己留个带点注释的模板,以后要是忘了作提醒用。
以洛谷3803多项式乘法裸题为例。
FFT:
1 #include <cstdio> 2 #include <cmath> 3 #include <cctype> 4 #include <algorithm> 5 #define ri readint() 6 #define gc getchar() 7 8 int readint() { 9 int x = 0, s = 1, c = gc; 10 while (c <= 32) c = gc; 11 if (c == '-') s = -1, c = gc; 12 for (; isdigit(c); c = gc) x = x * 10 + c - 48; 13 return x * s; 14 } 15 16 const int maxn = 4 * 1e6 + 10; 17 const double PI = acos(-1.0); 18 19 struct Complex { 20 double x, y; 21 Complex(double a = 0, double b = 0):x(a), y(b){} 22 }; 23 Complex operator + (Complex A, Complex B) { return Complex(A.x + B.x, A.y + B.y); } 24 Complex operator - (Complex A, Complex B) { return Complex(A.x - B.x, A.y - B.y); } 25 Complex operator * (Complex A, Complex B) { return Complex(A.x * B.x - A.y * B.y, A.x * B.y + A.y * B.x); } 26 27 Complex a[maxn], b[maxn]; 28 int n, m; 29 int r[maxn], l, limit = 1; 30 31 void fft(Complex *A, int type) { 32 for (int i = 0; i < limit; i++) 33 if (i < r[i]) 34 std::swap(A[i], A[r[i]]); 35 //迭代方式模拟递归写法,需要理解递归是怎么做的才能看懂这个 36 for (int mid = 1; mid < limit; mid <<= 1) { 37 //本来单位根是2*PI/len,这里len替换成2*mid,2就约掉了 38 Complex Wn(cos(PI / mid), type * sin(PI / mid)); 39 for (int R = mid << 1, j = 0; j < limit; j += R) { 40 Complex w(1, 0);//单位根的k次幂 41 for (int k = 0; k < mid; k++, w = w * Wn) { 42 //蝴蝶变换 43 Complex x = A[j+k], y = w * A[j+k+mid]; 44 A[j+k] = x + y; 45 A[j+k+mid] = x - y; 46 } 47 } 48 } 49 } 50 51 int main() { 52 n = ri, m = ri; 53 for (int i = 0; i <= n; i++) 54 a[i].x = ri; 55 for (int i = 0; i <= m; i++) 56 b[i].x = ri; 57 58 while (limit <= n + m) {//长度变为2^l 59 limit <<= 1; 60 l++; 61 } 62 for (int i = 0; i < limit; i++)//二进制镜像 63 r[i] = (r[i>>1] >> 1) | ((i&1) << (l-1)); 64 fft(a, 1); 65 fft(b, 1); 66 for (int i = 0; i < limit; i++) 67 a[i] = a[i] * b[i]; 68 fft(a, -1); 69 for (int i = 0; i <= n + m; i++) 70 printf("%d ", (int)(a[i].x / limit + 0.5)); 71 return 0; 72 }
NTT是用模域取代了复数域,性质相同只是换了单位根,所以板子基本相同。我这两个相比NTT确实比FFT快一点的:
1 #include <bits/stdc++.h> 2 #define ll long long 3 #define ri readll() 4 #define gc getchar() 5 #define rep(i, a, b) for (int i = a; i <= b; i++) 6 using namespace std; 7 8 const int P = 998244353, G = 3, Gi = 332748118, maxn = 4 * 1e6 + 5; 9 //P的原根为3,3%P的逆元为332748118 10 //原根意味着:3^(P-1) % P = 1,其中P-1是3%P的阶,本应是φ(P),这里恰好为大素数 11 ll n, m; 12 ll a[maxn], b[maxn]; 13 int limit = 1, l, r[maxn]; 14 15 ll readll() { 16 ll x = 0ll, s = 1ll; 17 char c = gc; 18 while (c <= 32) c = gc; 19 if (c == '-') s = -1ll, c = gc; 20 for (; isdigit(c); c = gc) x = x * 10 + c - 48; 21 return x * s; 22 } 23 24 ll ksm(ll a, ll b, int mod) { 25 ll res = 1ll; 26 for (; b; b >>= 1) { 27 if (b & 1) res = res * a % mod; 28 a = a * a % mod; 29 } 30 return res; 31 } 32 33 void NTT(ll *A, int flag) { 34 rep(i, 0, limit) 35 if (i < r[i]) 36 swap(A[i], A[r[i]]); 37 38 for (int mid = 1; mid < limit; mid <<= 1) { 39 //如果是变换则单位根为3^[(P-1)/(len)] % P,逆变换则用逆元 40 ll Wn = ksm(flag ? G : Gi, (P-1) / (mid*2), P); 41 for (int R = mid << 1, j = 0; j < limit; j += R) { 42 ll w = 1ll; 43 for (int k = 0; k < mid; k++, w = w * Wn % P) { 44 ll x = A[j+k], y = A[j+k+mid] * w % P; 45 A[j+k] = (x + y) % P; 46 A[j+k+mid] = (x - y + P) % P; 47 } 48 } 49 } 50 } 51 52 int main() { 53 n = ri, m = ri; 54 rep(i, 0, n) a[i] = (ri + P) % P; 55 rep(i, 0, m) b[i] = (ri + P) % P; 56 57 while (limit < n + m + 1) { 58 limit <<= 1; 59 l++; 60 } 61 rep(i, 0, limit) r[i] = (r[i>>1] >> 1) | ((i & 1) << (l - 1)); 62 NTT(a, 1); NTT(b, 1); 63 rep(i, 0, limit) a[i] = a[i] * b[i] % P; 64 NTT(a, 0); 65 66 ll inv = ksm(limit, P - 2, P);//最后变换回来要乘长度的逆元 67 rep(i, 0, n + m) printf("%lld ", a[i] * inv % P); 68 69 return 0; 70 }