zoukankan      html  css  js  c++  java
  • 多项式模板集

    NTT && FFT

    NTT板子

    // 64-bit integer type used for products that would overflow int.
    typedef long long ll;
    // P is the modulus for all arithmetic below; g is the base whose powers
    // prework() uses to build the root table.
    const int P = 998244353;
    const int g = 3;
    // Base capacity; the lookup tables below are declared with maxn << 2 entries.
    const int maxn = 1111111;
    
    // Returns a + b, reduced by a single subtraction of P when the sum reaches P.
    // NOTE(review): callers appear to pass values in [0, P]; confirm before reusing elsewhere.
    int inc(int a, int b) {
    	const int sum = a + b;
    	return sum >= P ? sum - P : sum;
    }
    // Binary exponentiation: returns a^b modulo P.
    // Products go through 64-bit arithmetic (1ll) before the reduction.
    int qpow(int a, int b) {
    	int acc = 1;
    	while (b) {
    		if (b & 1) acc = 1ll*acc*a%P;
    		a = 1ll*a*a%P;
    		b >>= 1;
    	}
    	return acc;
    }
    // Lookup tables with 4x maxn entries each.
    //   W[q + k], for a power of two q and 0 <= k < q, is qpow(g, (P-1)/(2q)) raised to the k-th power.
    //   inv[i] is filled by the linear recurrence inv[i] = (P - P/i) * inv[P % i] mod P.
    int W[maxn << 2], inv[maxn << 2];
    // Fills W for every power-of-two block strictly below n, and inv[1..n] inclusive.
    void prework(int n) {
    	for (int len = 1; len < n; len <<= 1) { // strict bound: the block at n itself is not built
    		const int step = qpow(g, (P-1)/len>>1);
    		W[len] = 1;
    		for (int k = len+1; k < len+len; k++) W[k] = 1ll*W[k-1]*step%P;
    	}
    	inv[1] = 1;
    	for (int v = 2; v <= n; v++) // inclusive bound: ntt() reads inv[n]
    		inv[v] = 1ll*(P-P/v)*inv[P%v]%P;
    }
    // In-place length-n number-theoretic transform of a[0..n).
    //   opt == -1 selects the inverse transform; any other value selects the forward one.
    // Reads W[q + k] for every power-of-two q < n and, for the inverse, inv[n],
    // so prework() must have been called with an argument of at least n.
    // NOTE(review): n is presumably a power of two — confirm at call sites.
    void ntt(int *a, int n, int opt) {
    	static int rev[maxn << 2] = {0}; // zero-initialised, so rev[0] is 0
    	// Bit-reversal permutation; rev[idx] is derived from rev[idx >> 1].
    	for (int idx = 1; idx < n; idx++) {
    		rev[idx] = (rev[idx>>1] >> 1) | ((idx & 1) ? (n >> 1) : 0);
    		if (rev[idx] > idx) std::swap(a[idx], a[rev[idx]]);
    	}
    	// Butterfly passes over blocks of size 2 * half.
    	for (int half = 1; half < n; half <<= 1)
    		for (int base = 0; base < n; base += half<<1)
    			for (int k = 0; k < half; k++) {
    				int &lo = a[base+k], &hi = a[base+half+k];
    				const int t = 1ll*hi*W[half+k]%P;
    				hi = inc(lo, P-t);
    				lo = inc(lo, t);
    			}
    	if (opt != -1) return;
    	// Inverse: reverse a[1..n) and scale every entry by inv[n].
    	std::reverse(a+1, a+n);
    	for (int idx = 0; idx < n; idx++) a[idx] = 1ll*a[idx]*inv[n]%P;
    }
    
    // Returns the smallest power of two that is >= n (1 when n <= 1).
    int getsize(int n) {
    	int size = 1;
    	for (; size < n; size <<= 1) {}
    	return size;
    }
    

    FFT板子

    // Hand-rolled complex number: x is the real part, y the imaginary part.
    struct comp { double x, y; };
    // The operators build their result with standard aggregate initialisation
    // instead of the GNU-only compound literal `(comp){...}`, which is not ISO C++.
    // Component-wise sum.
    comp operator + (comp a, comp b) {
    	comp r = {a.x+b.x, a.y+b.y};
    	return r;
    }
    // Component-wise difference.
    comp operator - (comp a, comp b) {
    	comp r = {a.x-b.x, a.y-b.y};
    	return r;
    }
    // Complex product: (a.x + a.y i)(b.x + b.y i).
    comp operator * (comp a, comp b) {
    	comp r = {a.x*b.x-a.y*b.y, a.x*b.y+a.y*b.x};
    	return r;
    }
    
    comp W[maxn << 2];
    void prework(int n) {
    	for (int i = 1; i < n; i <<= 1)
    		for (int j = 0; j < i; j++)
    			W[i+j] = (comp){cos(PI/i*j), sin(PI/i*j)};
                // 这里直接算,防止丢精度。这在MTT中很重要
    }
    
    // In-place length-n complex FFT of a[0..n); structurally the same as the NTT version.
    //   opt == -1 selects the inverse transform; any other value selects the forward one.
    // Reads W[q + k] for every power-of-two q < n, so prework() must cover n.
    // NOTE(review): n is presumably a power of two — confirm at call sites.
    void fft(comp *a, int n, int opt) {
    	static int rev[maxn << 2] = {0};
    	// Bit-reversal permutation; rev[idx] is derived from rev[idx >> 1].
    	for (int idx = 1; idx < n; idx++) {
    		rev[idx] = (rev[idx>>1] >> 1) | ((idx & 1) ? (n >> 1) : 0);
    		if (rev[idx] > idx) std::swap(a[idx], a[rev[idx]]);
    	}
    	for (int half = 1; half < n; half <<= 1)
    		for (int base = 0; base < n; base += half<<1)
    			for (int k = 0; k < half; k++) {
    				const comp t = a[base+half+k]*W[half+k];
    				const comp u = a[base+k];
    				a[base+half+k] = u-t;
    				a[base+k] = u+t;
    			}
    	if (opt != -1) return;
    	// Inverse: reverse a[1..n) and divide both components by n.
    	std::reverse(a+1, a+n);
    	for (int idx = 0; idx < n; idx++) {
    		a[idx].x /= n;
    		a[idx].y /= n;
    	}
    }
    

    任意模数NTT(MTT)

    MTT事实上就是将一个大整数拆分成两部分 \(a\times Base+b\),分开相乘最后相加就能保证精度了。直接做的话FFT次数很多、常数很大,在拜读完myy的论文和代码后发现有一种只做4次FFT的方法,需要一些trick和引理。

    考虑构造 \(A_j=a_j+b_j i\) 和 \(B_j=a_j-b_j i\),如果能快速求出 \(FFT(A)\) 和 \(FFT(B)\),那么就能加减消元求出 \(FFT(a)\) 和 \(FFT(b)\)。myy论文中讲如果求出了 \(FFT(A)\),可以利用其直接推出 \(FFT(B)\)。

    原文是这样说的:

    \[B(\omega^k)=\overline{A(\omega^{-k})} \]

    其中上划线表示共轭复数

    设 \(\omega^k\) 的幅角为 \(\theta\),注意到

    \[\begin{aligned} B(\omega^k)&=\sum_{j=0}^{n}B_j\omega^{jk}=\sum_{j=0}^{n}(a_j-b_j i)(\cos(j\theta)+i\sin(j\theta))\\ &=\sum_{j=0}^{n}\Big((a_j\cos(j\theta)+b_j\sin(j\theta))+i(a_j\sin(j\theta)-b_j\cos(j\theta))\Big)\\ &=\sum_{j=0}^{n}\Big((a_j\cos(-j\theta)-b_j\sin(-j\theta))-i(a_j\sin(-j\theta)+b_j\cos(-j\theta))\Big)\\ &=\overline{\sum_{j=0}^{n}\Big((a_j\cos(-j\theta)-b_j\sin(-j\theta))+i(a_j\sin(-j\theta)+b_j\cos(-j\theta))\Big)}\\ &=\overline{\sum_{j=0}^{n}(a_j+b_j i)(\cos(-j\theta)+i\sin(-j\theta))}\\ &=\overline{\sum_{j=0}^{n}A_j\omega^{-jk}}=\overline{A(\omega^{-k})} \end{aligned}\]

    运用这个性质优化可以大幅减少FFT的次数。

    void conv(int *x, int *y, int *z, int n) { // z=x*y,长度为n
    	for (int i = 0; i < n; i++) x[i] %= P, y[i] %= P; // 提前取模
    	static comp a[maxn << 2], b[maxn << 2], da[maxn << 2], db[maxn << 2], dc[maxn << 2], dd[maxn << 2];
    	for (int i = 0; i < n; i++)
    		a[i] = (comp){x[i] >> 15, x[i] & 32767}, b[i] = (comp){y[i] >> 15, y[i] & 32767}; // 将整数拆分成两部分
    	fft(a, n, 1), fft(b, n, 1);
    	for (int i = 0; i < n; i++) {
    		int j = (n-1) & (n-i);
    		static comp a1, a2, b1, b2; // 分离出x和y每个部分的插值结果
    		a1 = (a[i] + conj(a[j])) * (comp){0.5, 0};
    		a2 = (a[i] - conj(a[j])) * (comp){0, -0.5};
    		b1 = (b[i] + conj(b[j])) * (comp){0.5, 0};
    		b2 = (b[i] - conj(b[j])) * (comp){0, -0.5};
    		da[i] = a1*b1, db[i] = a1*b2, dc[i] = a2*b1, dd[i] = a2*b2;
    	}
    	for (int i = 0; i < n; i++)
    		a[i] = da[i] + db[i]*(comp){0, 1}, b[i] = dc[i] + dd[i]*(comp){0, 1}; // IDFT(x+yi)=IDFT(x)+iIDFT(y),这个将用于下文
    	fft(a, n, -1), fft(b, n, -1);
    	for (int i = 0; i < n; i++) {
    		int ax = (ll)(a[i].x+0.5)%P, ay = (ll)(a[i].y+0.5)%P, bx = (ll)(b[i].x+0.5)%P, by = (ll)(b[i].y+0.5)%P; // 一定要转化成ll(数值为2^30*n)
    		z[i] = (((ll)ax << 30) + ((ll)(ay + bx) << 15) + by) % P;
    	}
    }
    

    快速沃尔什变换(FWT)

    用于解决对下标进行位运算卷积问题的方法。即

    \[c_k=\sum_{i\oplus j=k}a_i\times b_j \]

    其中 \(\oplus\) 分别为 |、&、^ 的情况。

    先考虑 |。FWT干这样的事情:类似于FFT,对于每一个 \(i\),它要求出 \(fwt(a)_i=\sum_{j|i=i}a_j\),并且能以 \(\mathcal O(n\log n)\) 的时间复杂度在两者之间快速变换。然后有 \(j|i=i,\ k|i=i\Leftrightarrow(j|k)|i=i\)。

    如果求出了 \(fwt(a)\) 和 \(fwt(b)\),发现有

    \[\begin{aligned} fwt(a)_i\times fwt(b)_i&=\left(\sum_{j|i=i}a_j\right)\left(\sum_{k|i=i}b_k\right)\\ &=\sum_{j|i=i,k|i=i}a_jb_k=\sum_{(j|k)|i=i}a_jb_k=\sum_{t|i=i}\sum_{j|k=t}a_jb_k=fwt(c)_i\end{aligned}\]

    很像系数变点值!考虑怎么变换。我们按最高位为0或1来分成两组序列 \(a^{[0]}\) 和 \(a^{[1]}\),不难发现

    \[fwt(a)=merge(fwt(a^{[0]}),fwt(a^{[0]})+fwt(a^{[1]})) \]

    最高位是0的序列对应最高位是1的一定是包含关系,所以右边相加。

    同理我们不难发现

    \[a=merge(a^{[0]},a^{[1]}-a^{[0]}) \]

    // OR transform, in place on a[0..n).
    //   opt == -1 applies the inverse (subtracts the low half from the high half);
    //   any other value applies the forward transform (adds it).
    void fwt_or(int *a, int n, int opt) {
    	const bool forward = (opt != -1);
    	for (int half = 1; half < n; half <<= 1)
    		for (int base = 0; base < n; base += half<<1)
    			for (int k = 0; k < half; k++) {
    				const int low = a[base+k];
    				a[base+half+k] = inc(a[base+half+k], forward ? low : P-low);
    			}
    }
    

    对于&,也满足 \(fwt(a)_i\times fwt(b)_i=fwt(c)_i\)。同上我们也可以推导出

    \[fwt(a)=merge(fwt(a^{[0]})+fwt(a^{[1]}),fwt(a^{[1]})) \]

    \[a=merge(a^{[0]}-a^{[1]},a^{[1]}) \]

    对于^稍微麻烦些,定义 \(x\otimes y\) 为 \(x\&y\) 的二进制表示中1的个数对2取模的结果。有

    \[(i\otimes j)\ \mathrm{xor}\ (i\otimes k)=i\otimes(j\ \mathrm{xor}\ k) \]

    构造 \(fwt(a)_i=\sum_{i\otimes j=0}a_j-\sum_{i\otimes j=1}a_j\),则

    \[\begin{aligned} fwt(a)_i\times fwt(b)_i&=\left(\sum_{i\otimes j=0}a_j-\sum_{i\otimes j=1}a_j\right)\left(\sum_{i\otimes k=0}b_k-\sum_{i\otimes k=1}b_k\right)\\ &=\sum_{i\otimes j=0,i\otimes k=0}a_jb_k-\sum_{i\otimes j=0,i\otimes k=1}a_jb_k-\sum_{i\otimes j=1,i\otimes k=0}a_jb_k+\sum_{i\otimes j=1,i\otimes k=1}a_jb_k\\ &=\sum_{(i\otimes j)\,\mathrm{xor}\,(i\otimes k)=0}a_jb_k-\sum_{(i\otimes j)\,\mathrm{xor}\,(i\otimes k)=1}a_jb_k\\ &=\sum_{i\otimes(j\,\mathrm{xor}\,k)=0}a_jb_k-\sum_{i\otimes(j\,\mathrm{xor}\,k)=1}a_jb_k\\ &=fwt(c)_i \end{aligned} \]

    符合要求,所以能推导出

    \[fwt(a)=merge(fwt(a^{[0]})+fwt(a^{[1]}),fwt(a^{[0]})-fwt(a^{[1]})) \]

    \[a=merge\left(\frac{a^{[0]}+a^{[1]}}2,\frac{a^{[0]}-a^{[1]}}2\right) \]

    // AND transform, in place on a[0..n) (the XOR transform follows below).
    //   opt == -1 applies the inverse (subtracts the high half from the low half);
    //   any other value applies the forward transform (adds it).
    void fwt_and(int *a, int n, int opt) {
    	const bool forward = (opt != -1);
    	for (int half = 1; half < n; half <<= 1)
    		for (int base = 0; base < n; base += half<<1)
    			for (int k = 0; k < half; k++) {
    				const int high = a[base+half+k];
    				a[base+k] = inc(a[base+k], forward ? high : P-high);
    			}
    }
    
    void fwt_xor(int *a, int n, int opt) {
    	for (int q = 1; q < n; q <<= 1)
    		for (int p = 0; p < n; p += q<<1)
    			for (int i = 0; i < q; i++) {
    				int t = a[p+q+i]; a[p+q+i] = inc(a[p+i], P-t); a[p+i] = inc(a[p+i], t);
    				if (opt == -1) a[p+i] = 1ll*a[p+i]*inv2%P, a[p+q+i] = 1ll*a[p+q+i]*inv2%P;
    			}
    }
    

    多项式

    #include <bits/stdc++.h>
    using std::max;
    using std::reverse;
    using std::swap;
    using std::vector;
    // N: base capacity (the tables below hold N*4 entries).
    const int N = 100005;
    // P: modulus for all polynomial arithmetic.
    const int P = 998244353;
    // inv2: (P + 1) / 2, i.e. the inverse of 2 modulo P.
    const int inv2 = (P + 1) >> 1;
    // A polynomial is its coefficient vector, constant term first.
    typedef vector<int> Poly;
    typedef long long LL;
    // Returns a + b, reduced by a single subtraction of P when the sum reaches P.
    // NOTE(review): callers appear to pass values in [0, P]; confirm before reusing elsewhere.
    int inc(int a, int b) {
    	const int sum = a + b;
    	return sum >= P ? sum - P : sum;
    }
    // Binary exponentiation: returns a^b modulo P (products go through 64-bit arithmetic).
    int pow(int a, int b) {
    	int acc = 1;
    	while (b) {
    		if (b & 1) acc = 1LL*acc*a%P;
    		b >>= 1;
    		a = 1LL*a*a%P;
    	}
    	return acc;
    }
    int W[N*4], inv[N*4];
    void prework(int n) {
    	for (int i = 1; i < n; i <<= 1)
    		for (int j = W[i] = 1, Wn = pow(3, (P-1)/i>>1); j < i; j++)
    			W[i+j] = 1LL*W[i+j-1]*Wn%P;
    	inv[1] = 1;
    	for (int i = 2; i <= n; i++) inv[i] = 1LL*(P-P/i)*inv[P%i]%P;
    }
    // In-place length-n number-theoretic transform of a; a is first resized to n.
    //   opt != 0 : forward transform.
    //   opt == 0 : inverse transform (scales by pow(n, P-2), then reverses a[1..n)).
    // Reads W[q + k] for every power-of-two q < n, so prework() must cover n.
    // NOTE(review): n is presumably a power of two — confirm at call sites.
    void fft(Poly &a, int n, int opt) {
    	static int rev[N*4];
    	a.resize(n);
    	// Bit-reversal permutation; rev[idx] is derived from rev[idx >> 1].
    	for (int idx = 1; idx < n; idx++) {
    		rev[idx] = (rev[idx>>1] >> 1) | ((idx & 1) ? (n >> 1) : 0);
    		if (rev[idx] > idx) swap(a[idx], a[rev[idx]]);
    	}
    	for (int half = 1; half < n; half <<= 1)
    		for (int base = 0; base < n; base += half<<1)
    			for (int k = 0; k < half; k++) {
    				int &lo = a[base+k], &hi = a[base+half+k];
    				const int t = 1LL*W[half+k]*hi%P;
    				hi = inc(lo, P-t);
    				lo = inc(lo, t);
    			}
    	if (opt != 0) return;
    	const int n_inv = pow(n, P-2);
    	for (int idx = 0; idx < n; idx++) a[idx] = 1LL*a[idx]*n_inv%P;
    	reverse(a.begin()+1, a.end());
    }
    // Returns the first A.size() coefficients of the power-series inverse of A.
    // Starts from pow(A[0], P-2) and doubles the precision each round with
    // B <- 2B - B*B*A, evaluated pointwise on transforms of length 4 * (current length).
    // NOTE(review): A is assumed non-empty with an invertible constant term.
    Poly poly_inv(Poly A) {
    	const int want = A.size();
    	Poly res(1, pow(A[0], P-2));
    	for (int have = 1; have < want; have <<= 1) {
    		const int next = have*2, m = have*4;
    		Poly cur(A);
    		cur.resize(next); // only the first `next` coefficients of A matter this round
    		fft(res, m, 1);
    		fft(cur, m, 1);
    		for (int i = 0; i < m; i++) {
    			const LL sq = 1LL*res[i]*res[i]%P;
    			res[i] = (2*res[i]-sq*cur[i]%P+P)%P;
    		}
    		fft(res, m, 0);
    		res.resize(next);
    	}
    	res.resize(want);
    	return res;
    }
    // Returns the smallest power of two that is >= n (1 when n <= 1).
    int getsz(int n) {
    	int size = 1;
    	while (size < n) {
    		size <<= 1;
    	}
    	return size;
    }
    // Drops trailing zero coefficients from A, always keeping at least one
    // coefficient when A is non-empty.
    void fix(Poly &A) {
    	int len = A.size();
    	while (len > 1 && A[len-1] == 0) len--;
    	A.resize(len);
    }
    // Coefficient-wise sum modulo P; trailing zeros are trimmed from the result.
    Poly operator + (Poly A, Poly B) {
    	if (A.size() < B.size()) A.resize(B.size());
    	const int terms = B.size();
    	for (int i = 0; i < terms; i++) A[i] = inc(A[i], B[i]);
    	fix(A);
    	return A;
    }
    // Coefficient-wise difference modulo P; trailing zeros are trimmed from the result.
    Poly operator - (Poly A, Poly B) {
    	if (A.size() < B.size()) A.resize(B.size());
    	const int terms = B.size();
    	for (int i = 0; i < terms; i++) A[i] = inc(A[i], P-B[i]);
    	fix(A);
    	return A;
    }
    // Scalar multiple: every coefficient of A is multiplied by k modulo P.
    Poly operator * (int k, Poly A) {
    	for (Poly::iterator it = A.begin(); it != A.end(); ++it)
    		*it = 1LL*k*(*it)%P;
    	return A;
    }
    // Full polynomial product modulo P via the transform: both operands are taken to a
    // power-of-two length covering A.size()+B.size()-1 terms, multiplied pointwise and
    // transformed back. Trailing zeros are trimmed from the result.
    Poly operator * (Poly A, Poly B) {
    	const int len = getsz(A.size()+B.size()-1);
    	fft(A, len, 1);
    	fft(B, len, 1);
    	for (int i = 0; i < len; i++) A[i] = 1LL*A[i]*B[i]%P;
    	fft(A, len, 0);
    	fix(A);
    	return A;
    }
    // Polynomial quotient of A by B modulo P (remainder discarded).
    // Both operands are reversed so the quotient's n = A.size()-B.size()+1 coefficients
    // come from a power-series inverse: rev(Q) = rev(A) * rev(B)^{-1} mod x^n.
    // When A has fewer coefficients than B the quotient is the zero polynomial. Without
    // this guard n would be <= 0, which either turns into a huge size_t inside resize()
    // or makes poly_inv read from an empty vector.
    // NOTE(review): B's highest coefficient is assumed non-zero (it becomes the constant
    // term that poly_inv inverts).
    Poly operator / (Poly A, Poly B) {
    	if (A.size() < B.size()) return Poly(1, 0);
    	const int n = A.size()-B.size()+1;
    	reverse(A.begin(), A.end()); A.resize(n);
    	reverse(B.begin(), B.end()); B.resize(n);
    	A = A * poly_inv(B);
    	A.resize(n); // operator* trims trailing zeros; restore exactly n coefficients
    	reverse(A.begin(), A.end());
    	fix(A);
    	return A;
    }
    // Polynomial remainder: A minus (A / B) * B.
    Poly operator % (Poly A, Poly B) {
    	Poly quotient = A / B;
    	return A - quotient * B;
    }
    // Formal derivative modulo P; the result has one coefficient fewer than A.
    // An empty input is returned unchanged: without the guard, A.size()-1 wraps around
    // as an unsigned value, so the loop reads out of bounds and resize() gets a huge size.
    Poly poly_deri(Poly A) {
    	if (A.empty()) return A;
    	const int m = A.size()-1;
    	for (int i = 0; i < m; i++) A[i] = 1LL*(i+1)*A[i+1]%P;
    	A.resize(m);
    	return A;
    }
    // Formal integral modulo P with constant term 0. The result keeps A's length, so
    // the term that would follow the last coefficient is dropped.
    // Reads inv[1..A.size()-1], so prework() must cover that range.
    // An empty input is returned unchanged: without the guard, the loop starts at index -1
    // and A[0] is written on an empty vector.
    Poly poly_int(Poly A) {
    	if (A.empty()) return A;
    	for (int i = A.size()-1; i > 0; i--) A[i] = 1LL*A[i-1]*inv[i]%P;
    	A[0] = 0;
    	return A;
    }
    // Returns the first A.size() coefficients of a power-series square root of A.
    // Starts from the constant 1 and doubles the precision each round with
    // B <- (B*B + A) / (2B), evaluated pointwise on transforms of length 4 * (current length).
    // NOTE(review): the seed is 1, so this presumably expects A[0] == 1 — confirm with callers.
    Poly poly_sqrt(Poly A) {
    	const int want = A.size();
    	Poly root(1, 1), root_inv;
    	for (int have = 1; have < want; have <<= 1) {
    		const int next = have*2, m = have*4;
    		Poly cur(A);
    		cur.resize(next);
    		root.resize(next);
    		root_inv = poly_inv(root);
    		fft(root, m, 1);
    		fft(root_inv, m, 1);
    		fft(cur, m, 1);
    		for (int i = 0; i < m; i++)
    			root[i] = (1LL*root[i]*root[i]+cur[i])%P*root_inv[i]%P*inv2%P;
    		fft(root, m, 0);
    		root.resize(next);
    	}
    	root.resize(want);
    	return root;
    }
    // Power-series logarithm: integrates A' * A^{-1}, truncated to A.size() coefficients.
    // NOTE(review): presumably expects A[0] == 1 — confirm with callers.
    Poly poly_ln(Poly A) {
    	Poly quot = poly_deri(A) * poly_inv(A);
    	quot.resize(A.size());
    	return poly_int(quot);
    }
    // Power-series exponential, truncated to A.size() coefficients.
    // Starts from the constant 1 and doubles the working length each round with
    // B <- B * (A + 1 - ln B).
    // NOTE(review): the seed is 1, so this presumably expects A[0] == 0 — confirm with callers.
    Poly poly_exp(Poly A) {
    	const int want = A.size();
    	const Poly one(1, 1);
    	Poly res(1, 1);
    	for (int have = 1; have < want; have <<= 1) {
    		const int next = have*2;
    		res.resize(next);
    		Poly corr = A + one - poly_ln(res);
    		corr.resize(next);
    		res = res*corr;
    	}
    	res.resize(want);
    	return res;
    }
    // Power-series k-th power, computed as exp(k * ln A).
    Poly poly_pow(Poly A, int k) {
    	Poly log_a = poly_ln(A);
    	return poly_exp(k * log_a);
    }
    // Child indices of tree node `o`. The macros expand using whatever variable
    // named `o` is in scope at the point of use.
    #define lc (o << 1)
    #define rc (o << 1 | 1)
    // Q[o]: polynomial stored at tree node o. build() sets each leaf to the linear
    // factor for x[l] and each inner node to the product of its two children.
    Poly Q[N*4];
    // Builds the product tree over points x[l..r]: the leaf for x[l] holds (X - x[l]),
    // and an inner node holds the product of its two children.
    //   o    : index of the current node in Q (children are lc and rc).
    //   l, r : inclusive range of point indices covered by node o.
    // The leaf is reset before being filled, so calling build() again does not append to
    // coefficients left over from an earlier call. The constant term is reduced modulo P
    // so that x[l] == 0 yields 0 instead of the unreduced value P.
    void build(int o, int l, int r, int x[]) {
    	if (l == r) {
    		Q[o].clear();
    		Q[o].push_back((P - x[l] % P) % P);
    		Q[o].push_back(1);
    		return;
    	}
    	const int mid = (l+r)>>1;
    	build(lc, l, mid, x);
    	build(rc, mid+1, r, x);
    	Q[o] = Q[lc] * Q[rc];
    }
    // Recursive step of multipoint evaluation: reduces A modulo each child's polynomial
    // on the way down. At a leaf, the remaining constant term is written to x[l],
    // overwriting the evaluation point that was stored there.
    void calc(Poly A, int o, int l, int r, int x[]) {
    	if (l == r) {
    		x[l] = A[0];
    		return;
    	}
    	const int mid = (l+r)>>1;
    	calc(A % Q[lc], lc, l, mid, x);
    	calc(A % Q[rc], rc, mid+1, r, x);
    }
    // Multipoint evaluation: builds the product tree over x[1..n], then runs calc()
    // from the root, which replaces each x[i] with the value computed for that point.
    void poly_calc(Poly A, int n, int x[]) {
    	build(1, 1, n, x);
    	calc(A, 1, 1, n, x);
    }
    // Recursive step of interpolation over the product tree.
    // A leaf returns the constant y[l] / x[l] (mod P); an inner node combines its
    // children as left * Q[rc] + right * Q[lc].
    // NOTE(review): poly_inter() runs calc() first, so x[l] here holds the value calc()
    // wrote for point l rather than the original point.
    Poly inter(int o, int l, int r, int x[], int y[]) {
    	if (l == r) {
    		const int weight = 1LL*y[l]*pow(x[l], P-2)%P;
    		return Poly(1, weight);
    	}
    	const int mid = (l+r)>>1;
    	Poly left = inter(lc, l, mid, x, y);
    	Poly right = inter(rc, mid+1, r, x, y);
    	return left*Q[rc] + right*Q[lc];
    }
    // Interpolation through the points (x[i], y[i]) for i = 1..n.
    // Builds the product tree, evaluates the derivative of the root polynomial at every
    // point (calc() stores those values into x, overwriting the points), then combines
    // the pieces with inter().
    Poly poly_inter(int n, int x[], int y[]) {
    	build(1, 1, n, x);
    	calc(poly_deri(Q[1]), 1, 1, n, x);
    	return inter(1, 1, n, x, y);
    }
    int n; Poly A;
    int main() {
    	scanf("%d", &n); A.resize(n); prework(n*2);
    	for (int i = 0; i < n; i++) scanf("%d", &A[i]);
    	A = poly_exp(A);
    	for (int i = 0; i < n; i++) printf("%d ", A[i]);
    	return 0;
    }
    
  • 相关阅读:
    指针数组与数组指针
    209. 长度最小的子数组
    面试题 05.08. 绘制直线(位运算)
    1160. 拼写单词
    88. 合并两个有序数组
    80. 删除排序数组中的重复项 II(On)
    python自定义异常和主动抛出异常
    类的构造方法1(类中的特殊方法)
    python之判断一个值是不是可以被调用
    主动调用其他类的成员(普通调用和super方法调用)
  • 原文地址:https://www.cnblogs.com/ac-evil/p/13048974.html
Copyright © 2011-2022 走看看