zoukankan      html  css  js  c++  java
  • FFT & NTT

    FFT 快速傅里叶变换

    (O(nlogn)) 计算多项式乘法

    参考博客

    系数表示法 转换为 点值表示法

    [omega_n^k = cos(dfrac {2picdot k} n) + i cdot sin(dfrac {2pi cdot k} n) ]

    [A(x)=a_0+a_1*x+a_2*{x^2}+a_3*{x^3}+a_4*{x^4}+a_5*{x^5}+\ dots+a_{n-2}*x^{n-2}+a_{n-1}*x^{n-1} ]

    [A(x)=(a_0+a_2*{x^2}+a_4*{x^4}+dots+a_{n-2}*x^{n-2})+\(a_1*x+a_3*{x^3}+a_5*{x^5}+ dots+a_{n-1}*x^{n-1}) ]

    [A_1(x)=a_0+a_2*{x}+a_4*{x^2}+dots+a_{n-2}*x^{frac{n}{2}-1} ]

    [A_2(x)=a_1+a_3*{x}+a_5*{x^2}+ dots+a_{n-1}*x^{frac{n}{2}-1} ]

    [A(x)=A_1(x^2)+xA_2(x^2) ]

    带入 (x = omega_n^k)

    [A(omega_n^k) = A_1(omega_{frac n2}^k) + omega_n^kA_2(omega_{frac n2}^k) ]

    带入 $x = omega_n^{k+frac n2} $

    [A(omega_n^{k+frac n2}) = A_1(omega_{frac n2}^k) -omega_n^kA_2(omega_{frac n2}^k) ]

    也就是说如果知道了 $A_1(x),A_2(x) $ 分别在 (omega_{frac n2}^0) , (omega_{frac n2}^1) , (omega_{frac n2}^2) ,...,(omega_{frac n2}^{frac n2 -1}) 的取值,

    就可以 (O(n)) 的求出 (A(x))

    void fft(cp *a,int n,int inv)//inv是取共轭复数的符号
    {
        if (n==1)return;
        int mid=n/2;
        static cp b[MAXN];
        for(int i = 0;i < mid;i++)b[i]=a[i*2],b[i+mid]=a[i*2+1];
        
        for(int i = 0;i < n;i++)a[i]=b[i];
        fft(a,mid,inv),fft(a+mid,mid,inv);//分治
        
        for(int i = 0;i < mid;i++)
        {
            cp x(cos(2*pi*i/n),inv*sin(2*pi*i/n));//inv取决是否取共轭复数
            b[i]=a[i]+x*a[i+mid],b[i+mid]=a[i]-x*a[i+mid];
        }
        for(int i = 0;i < a;i++)a[i]=b[i];
    }
    

    每个位置分治后最终的位置是二进制翻转后的位置

    void fft(cp *a,int n,int inv)
    {
        int bit=0;
        while ((1<<bit)<n)bit++;
        fo(i,0,n-1)
        {
            rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
            if (i<rev[i])swap(a[i],a[rev[i]]);//不加这条if会交换两次(就是没交换)
        }
        for (int mid=1;mid<n;mid*=2)//mid是准备合并序列的长度的二分之一
        {
        	cp temp(cos(pi/mid),inv*sin(pi/mid));//单位根,pi的系数2已经约掉了
            for (int i=0;i<n;i+=mid*2)//mid*2是准备合并序列的长度,i是合并到了哪一位
    		{
                cp omega(1,0);
                for (int j=0;j<mid;j++,omega*=temp)//只扫左半部分,得到右半部分的答案
                {
                    cp x=a[i+j],y=omega*a[i+j+mid];
                    a[i+j]=x+y,a[i+j+mid]=x-y;//这个就是蝴蝶变换什么的
                }
            }
        }
    }
    

    洛谷模板

    注意 lim

    #include<bits/stdc++.h>
    using namespace std;
    
    const double pi = acos(-1.0);
    const int N = 3e6 + 10;
    
    struct cp {
    	double x, y;
    	cp() {}
    	cp(double _x, double _y) {
    		x = _x; y = _y;
    	}
    	cp operator + (cp b) {
    		return cp(x + b.x, y + b.y);
    	}
    	cp operator -(cp b) {
    		return cp(x - b.x, y - b.y);
    	}
    	cp operator *(cp b) {
    		return cp(x * b.x - y * b.y, x * b.y + y * b.x);
    	}
    };
    int rev[N];
    int bit = 0;
    int lim;
    void FFT(cp* a, int inv) {
    	
    	for (int i = 0; i < lim; i++) {
    		if (i < rev[i]) {
    			swap(a[i], a[rev[i]]);
    		}
    	}
    	
    	for (int mid = 1; mid < lim; mid <<= 1) {
    		cp temp(cos(pi / mid), inv * sin(pi / mid));
    		for (int i = 0; i < lim; i += mid * 2) {
    			cp omega(1, 0);
    			for (int j = 0; j < mid; j++, omega = omega * temp) {
    				cp x = a[i + j], y = omega * a[i + j + mid];
    				a[i + j] = x + y, a[i + j + mid] = x - y;
    			}
    		}
    	}
    }
    
    int n, m;
    
    cp A[N], B[N];
    
    int main() {
    	scanf("%d%d", &n, &m);
    	
    	lim = 1;
    	while (lim <= n + m)lim<<=1,bit++;//调整至 2^k
    
    	for (int i = 0; i < lim; i++) {
    		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    	}
    	for (int i = 0; i <= n; i++)scanf("%lf", &A[i].x), A[i].y = 0;
    	for (int i = 0; i <= m; i++)scanf("%lf", &B[i].x), B[i].y = 0;
    
    	FFT(A, 1);
    	FFT(B, 1);
    	for (int i = 0; i <= lim; i++) {
    		A[i] = A[i] * B[i];
    	}
    	FFT(A, -1);
    	for (int i = 0; i <= n + m; i++) {
    		printf("%d ", int(A[i].x /lim+0.5));
    	}
    
    }
    

    NTT

    参考博客

    原根

    还没有整太明白

    待补,丢一个板子

    #include<bits/stdc++.h>
    #define swap(a,b) (a^=b,b^=a,a^=b)
    using namespace std;
    
    #define LL long long 
    const int MAXN = 3 * 1e6 + 10, P = 998244353, G = 3, Gi = 332748118;
    char buf[1 << 21], * p1 = buf, * p2 = buf;
    
    int N, M, limit = 1, L, r[MAXN];
    LL a[MAXN], b[MAXN];
    inline LL fastpow(LL a, LL k) {
    	LL base = 1;
    	while (k) {
    		if (k & 1) base = (base * a) % P;
    		a = (a * a) % P;
    		k >>= 1;
    	}
    	return base % P;
    }
    inline void NTT(LL* A, int type) {
    	for (int i = 0; i < limit; i++)
    		if (i < r[i]) swap(A[i], A[r[i]]);
    	for (int mid = 1; mid < limit; mid <<= 1) {
    		LL Wn = fastpow(type == 1 ? G : Gi, (P - 1) / (mid << 1));
    		for (int j = 0; j < limit; j += (mid << 1)) {
    			LL w = 1;
    			for (int k = 0; k < mid; k++, w = (w * Wn) % P) {
    				int x = A[j + k], y = w * A[j + k + mid] % P;
    				A[j + k] = (x + y) % P,
    					A[j + k + mid] = (x - y + P) % P;
    			}
    		}
    	}
    }
    int main() {
    	scanf("%d%d", &N, &M);
    	for (int i = 0; i <= N; i++) scanf("%d", a + i);
    	for (int i = 0; i <= M; i++) scanf("%d", b + i);
    
    	while (limit <= N + M) limit <<= 1, L++;
    	for (int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    	NTT(a, 1); NTT(b, 1);
    	for (int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % P;
    	NTT(a, -1);
    	LL inv = fastpow(limit,	 P - 2);
    	for (int i = 0; i <= N + M; i++)
    		printf("%d ", (a[i] * inv) % P);
    	return 0;
    }
    
  • 相关阅读:
    基于MFC的Media Player播放器的制作(1---播放器界面的布局)
    Codeforces 1182
    Codeforces 1169
    Codeforces 1167
    Codeforces 1166
    Codeforces 1148
    *Codeforces 1162
    Codeforces 1159
    点分治
    高斯消元*
  • 原文地址:https://www.cnblogs.com/sduwh/p/13775590.html
Copyright © 2011-2022 走看看