zoukankan      html  css  js  c++  java
  • FFT与NTT的模板

    网上相关博客不少,这里给自己留个带点注释的模板,以后要是忘了作提醒用。

    洛谷3803多项式乘法裸题为例。

    FFT:

     1 #include <cstdio>
     2 #include <cmath>
     3 #include <cctype>
     4 #include <algorithm>
     5 #define ri readint()
     6 #define gc getchar()
     7 
     8 int readint() {
     9     int x = 0, s = 1, c = gc;
    10     while (c <= 32)    c = gc;
    11     if (c == '-')    s = -1, c = gc;
    12     for (; isdigit(c); c = gc)    x = x * 10 + c - 48;
    13     return x * s;
    14 }
    15 
    16 const int maxn = 4 * 1e6 + 10;
    17 const double PI = acos(-1.0);
    18 
    19 struct Complex {
    20     double x, y;
    21     Complex(double a = 0, double b = 0):x(a), y(b){}
    22 };
    23 Complex operator + (Complex A, Complex B) { return Complex(A.x + B.x, A.y + B.y); }
    24 Complex operator - (Complex A, Complex B) { return Complex(A.x - B.x, A.y - B.y); }
    25 Complex operator * (Complex A, Complex B) { return Complex(A.x * B.x - A.y * B.y, A.x * B.y + A.y * B.x); }
    26 
    27 Complex a[maxn], b[maxn];
    28 int n, m;
    29 int r[maxn], l, limit = 1;
    30 
    31 void fft(Complex *A, int type) {
    32     for (int i = 0; i < limit; i++)
    33         if (i < r[i])
    34             std::swap(A[i], A[r[i]]);
    35     //迭代方式模拟递归写法,需要理解递归是怎么做的才能看懂这个
    36     for (int mid = 1; mid < limit; mid <<= 1) {
    37         //本来单位根是2*PI/len,这里len替换成2*mid,2就约掉了
    38         Complex Wn(cos(PI / mid), type * sin(PI / mid));
    39         for (int R = mid << 1, j = 0; j < limit; j += R) {
    40             Complex w(1, 0);//单位根的k次幂
    41             for (int k = 0; k < mid; k++, w = w * Wn) {
    42                 //蝴蝶变换
    43                 Complex x = A[j+k], y = w * A[j+k+mid];
    44                 A[j+k] = x + y;
    45                 A[j+k+mid] = x - y;
    46             }
    47         }
    48     }
    49 }
    50 
    51 int main() {
    52     n = ri, m = ri;
    53     for (int i = 0; i <= n; i++)
    54         a[i].x = ri;
    55     for (int i = 0; i <= m; i++)
    56         b[i].x = ri;
    57 
    58     while (limit <= n + m) {//长度变为2^l
    59         limit <<= 1;
    60         l++;
    61     }
    62     for (int i = 0; i < limit; i++)//二进制镜像
    63         r[i] = (r[i>>1] >> 1) | ((i&1) << (l-1));
    64     fft(a, 1);
    65     fft(b, 1);
    66     for (int i = 0; i < limit; i++)
    67         a[i] = a[i] * b[i];
    68     fft(a, -1);
    69     for (int i = 0; i <= n + m; i++)
    70         printf("%d ", (int)(a[i].x / limit + 0.5));
    71     return 0;
    72 }

     NTT是用模域取代了复数域,性质相同只是换了单位根,所以板子基本相同。我这两个相比NTT确实比FFT快一点的:

     1 #include <bits/stdc++.h>
     2 #define ll long long
     3 #define ri readll()
     4 #define gc getchar()
     5 #define rep(i, a, b) for (int i = a; i <= b; i++)
     6 using namespace std;
     7 
     8 const int P = 998244353, G = 3, Gi = 332748118, maxn = 4 * 1e6 + 5;
     9 //P的原根为3,3%P的逆元为332748118
    10 //原根意味着:3^(P-1) % P = 1,其中P-1是3%P的阶,本应是φ(P),这里恰好为大素数
    11 ll n, m;
    12 ll a[maxn], b[maxn];
    13 int limit = 1, l, r[maxn];
    14 
    15 ll readll() {
    16     ll x = 0ll, s = 1ll;
    17     char c = gc;
    18     while (c <= 32)    c = gc;
    19     if (c == '-')    s = -1ll, c = gc;
    20     for (; isdigit(c); c = gc)    x = x * 10 + c - 48;
    21     return x * s;
    22 }
    23 
    24 ll ksm(ll a, ll b, int mod) {
    25     ll res = 1ll;
    26     for (; b; b >>= 1) {
    27         if (b & 1)    res = res * a % mod;
    28         a = a * a % mod;
    29     }
    30     return res;
    31 }
    32 
    33 void NTT(ll *A, int flag) {
    34     rep(i, 0, limit)
    35     if (i < r[i])
    36         swap(A[i], A[r[i]]);
    37 
    38     for (int mid = 1; mid < limit; mid <<= 1) {
    39         //如果是变换则单位根为3^[(P-1)/(len)] % P,逆变换则用逆元
    40         ll Wn = ksm(flag ? G : Gi, (P-1) / (mid*2), P);
    41         for (int R = mid << 1, j = 0; j < limit; j += R) {
    42             ll w = 1ll;
    43             for (int k = 0; k < mid; k++, w = w * Wn % P) {
    44                 ll x = A[j+k], y = A[j+k+mid] * w % P;
    45                 A[j+k] = (x + y) % P;
    46                 A[j+k+mid] = (x - y + P) % P;
    47             }
    48         }
    49     }
    50 }
    51 
    52 int main() {
    53     n = ri, m = ri;
    54     rep(i, 0, n)    a[i] = (ri + P) % P;
    55     rep(i, 0, m)    b[i] = (ri + P) % P;
    56 
    57     while (limit < n + m + 1) {
    58         limit <<= 1;
    59         l++;
    60     }
    61     rep(i, 0, limit)    r[i] = (r[i>>1] >> 1) | ((i & 1) << (l - 1));
    62     NTT(a, 1);    NTT(b, 1);
    63     rep(i, 0, limit)    a[i] = a[i] * b[i] % P;
    64     NTT(a, 0);
    65 
    66     ll inv = ksm(limit, P - 2, P);//最后变换回来要乘长度的逆元
    67     rep(i, 0, n + m)    printf("%lld ", a[i] * inv % P);
    68 
    69     return 0;
    70 }
  • 相关阅读:
    Silverlight工具荟萃
    微软WindowsPhone7份额已经超过了Symbian
    WPF性能优化经验总结和整理综合帖
    长期提供WindowsPhone7培训 & HTML5培训 & Silverlight培训 & WPF培训
    微软首推msnNOW服务 网络社交化风暴愈演愈烈
    cppunit在vs2008下使用的环境搭建(下)
    [转]ruby中gets 和 gets.chomp 区别
    大四中软实习笔记20130226
    [转]Ruby中require、load和include区别
    大四中软实习笔记20130227
  • 原文地址:https://www.cnblogs.com/AlphaWA/p/10271241.html
Copyright © 2011-2022 走看看