zoukankan      html  css  js  c++  java
  • luoguP4705 玩游戏 分治FFT


    $$\begin{aligned} Ans(k) &= \sum\limits_{i=1}^n \sum\limits_{j=1}^m (a_i + b_j)^k \\ &= \sum\limits_{i=1}^n \sum\limits_{j=1}^m \sum\limits_{t=0}^k \binom{k}{t} a_i^t b_j^{k-t} \\ &= \sum\limits_{t=0}^k \binom{k}{t} \left(\sum\limits_{i=1}^n a_i^t\right) \left(\sum\limits_{j=1}^m b_j^{k-t}\right) \\ &= k! \sum\limits_{t=0}^k \frac{\sum_{i=1}^n a_i^t}{t!} \cdot \frac{\sum_{j=1}^m b_j^{k-t}}{(k-t)!} \end{aligned}$$

    右边是一个卷积，只需对 $t = 0, 1, \dots, T$（$T$ 为询问的最大次数 $k$）求出 $f(t) = \sum\limits_{i=1}^n a_i^t$（对 $b$ 同理）；最终所求期望为 $\frac{Ans(k)}{nm}$。


    考虑 $f$ 的普通生成函数（OGF）：

    $$\begin{aligned} F(x) &= \sum\limits_{i=0}^{\infty} f(i) x^i \\ &= \sum\limits_{i=0}^{\infty} \sum\limits_{j=1}^{n} a_j^i x^i \\ &= \sum\limits_{j=1}^n \sum\limits_{i=0}^{\infty} (a_j x)^i \\ &= \sum\limits_{j=1}^n \frac{1}{1 - a_j x} \end{aligned}$$

    那么,现在的问题在于如何求解

    $$\sum\limits_{i=1}^n \frac{1}{1 - a_i x}$$


    考虑分治 FFT。
    一种很好想的思路是：对区间 $[l, r]$，先分别求出 $\sum\limits_{i=l}^{mid} \frac{1}{1 - a_i x}$ 与 $\sum\limits_{i=mid+1}^{r} \frac{1}{1 - a_i x}$。
    它们一定是形如 $\frac{A}{B}$ 的式子，不妨设左边为 $\frac{A}{B}$，右边为 $\frac{C}{D}$，
    那么合并之后的形式为 $\frac{AD + BC}{BD}$，递归维护分子和分母即可。

    复杂度 $O(n \log^2 n)$。


    可以发现 $$(\ln(AB))' = (\ln A + \ln B)' = (\ln A)' + (\ln B)'$$
    因此，我们考虑 $$(\ln(1 - a_i x))' = \frac{-a_i}{1 - a_i x}$$
    注意 $\ln$ 中不能添加常数因子（对形式幂级数取 $\ln$ 要求常数项为 $1$），因此我们只能从这个形式来考虑：

    $$\begin{aligned} G(x) &= \sum\limits_{i=1}^n (\ln(1 - a_i x))' \\ &= \left(\ln \prod\limits_{i=1}^n (1 - a_i x)\right)' \end{aligned}$$

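    可以用分治 FFT 求出 $\prod\limits_{i=1}^n (1 - a_i x)$，再对其做一次多项式 $\ln$ 并求导，即得到 $G$。
    对单个 $a_i$ 比较两者的系数：

    $$\frac{1}{1 - a_i x} = a_i^0 + a_i^1 x + a_i^2 x^2 + a_i^3 x^3 + \cdots$$

    $$\frac{-a_i}{1 - a_i x} = -a_i^1 - a_i^2 x - a_i^3 x^2 - \cdots$$

    因此（注意常数项 $f(0) = n$）$$F(x) = -xG(x) + n$$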

    然后就可以啦


    #include <bits/stdc++.h>
    using namespace std;
    
    // Loop helpers used throughout the file.
    // `register` was deprecated in C++11 and removed in C++17 (clang rejects it
    // by default under -std=c++17), so `ri` now expands to a plain `int`.
    #define ri int
    #define rep(io, st, ed) for(ri io = st; io <= ed; io ++)
    #define drep(io, ed, st) for(ri io = ed; io >= st; io --)
    
    // Capacity of every global array declared below.
    constexpr int sid = 500000 + 5;
    // NTT-friendly prime 998244353 = 119 * 2^23 + 1 (NTT uses 3 as its root).
    constexpr int mod = 998244353;
    
    #define gc getchar
    // Reads one (optionally negative) decimal integer from stdin, skipping any
    // non-digit characters before it. The character is kept in an `int` so EOF
    // can be detected; if input ends before a digit appears, 0 is returned
    // instead of looping forever.
    inline int read() {
        int p = 0, w = 1;
        int c = gc();
        while(c != EOF && (c > '9' || c < '0')) {
            if(c == '-') w = -1;
            c = gc();
        }
        while(c >= '0' && c <= '9') p = p * 10 + c - '0', c = gc();
        return p * w;
    }
    
    // Modular addition; expects both operands already reduced into [0, mod).
    inline int Inc(int a, int b) {
        int s = a + b;
        if(s >= mod) s -= mod;
        return s;
    }
    // Modular subtraction; expects both operands already reduced into [0, mod).
    inline int Dec(int a, int b) {
        int s = a - b;
        if(s < 0) s += mod;
        return s;
    }
    // Modular product; widens to 64 bits before reducing.
    inline int mul(int a, int b) { return static_cast<int>(static_cast<long long>(a) * b % mod); }
    // Square-and-multiply exponentiation: returns a^k modulo `mod` for k >= 0.
    inline int fp(int a, int k) {
        int result = 1;
        int base = a;
        while(k) {
            if(k & 1) result = mul(result, base);
            base = mul(base, base);
            k >>= 1;
        }
        return result;
    }
    	
    // rev: bit-reversal permutation table consumed by NTT; every caller rebuilds
    //      it for its own transform size right before transforming.
    // fac / inv / ivf: factorials, modular inverses and inverse factorials,
    //      filled by Init_Inv.
    int rev[sid], fac[sid], inv[sid], ivf[sid];
    // a, b: the two input sequences, 1-indexed as read in main.
    // ak, bk: outputs of calc (scaled power sums f(i) / i!), which solve()
    //         then convolves in place.
    int a[sid], b[sid], ak[sid], bk[sid];
    	
    // Finds the smallest power of two n with n >= Mn (n = 1 when Mn <= 1) and
    // stores its base-2 logarithm in lg. Used to size NTT transforms.
    inline void init(int Mn, int &n, int &lg) {
        for(n = 1, lg = 0; n < Mn; ++lg) n *= 2;
    }
    	
    inline void NTT(int *a, int n, int opt) {
    	for(ri i = 0; i < n; i ++) 
    		if(i < rev[i]) swap(a[i], a[rev[i]]);
    	for(ri i = 1; i < n; i <<= 1)
    	for(ri j = 0, g = fp(3, (mod - 1) / (i << 1)); j < n; j += (i << 1))
    	for(ri k = j, G = 1; k < i + j; k ++, G = mul(G, g)) {
    		int x = a[k], y = mul(G, a[i + k]);
    		a[k] = (x + y >= mod) ? x + y - mod : x + y;
    		a[i + k] = (x - y < 0) ? x - y + mod : x - y;
    	}
    	if(opt == -1) {
    		reverse(a + 1, a + n);
    		int ivn = fp(n, mod - 2);
    		for(ri i = 0; i < n; i ++) a[i] = mul(a[i], ivn);
    	}
    }
    	
    int ia[sid], ib[sid];
    inline void Inv(int *a, int *b, int n) {
    	if(n == 1) { b[0] = fp(a[0], mod - 2); return; }
    	Inv(a, b, n >> 1);
    	
    	int N = 1, lg = 0; init(n + n, N, lg);
    	for(ri i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    	for(ri i = 0; i < N; i ++) ia[i] = ib[i] = 0;
    	for(ri i = 0; i < n; i ++) ia[i] = a[i], ib[i] = b[i];
    	
    	NTT(ia, N, 1); NTT(ib, N, 1);
    	for(ri i = 0; i < N; i ++) ia[i] = Dec((ib[i] << 1) % mod, mul(ia[i], mul(ib[i], ib[i])));
    	NTT(ia, N, -1);
    	
    	for(ri i = 0; i < n; i ++) b[i] = ia[i];
    }
    	
    // Fills, for 0 <= i <= n (indices 0 and 1 are always written):
    //   inv[i] = i^{-1}, fac[i] = i!, ivf[i] = (i!)^{-1}   (mod `mod`).
    // inv uses the linear recurrence inv[i] = -(mod / i) * inv[mod % i];
    // inv[0] is only a placeholder value of 1.
    inline void Init_Inv(int n) {
        inv[0] = inv[1] = 1;
        fac[0] = fac[1] = 1;
        ivf[0] = ivf[1] = 1;
        for(int i = 2; i <= n; i ++) {
            inv[i] = mul(inv[mod % i], mod - mod / i);  // mod % i < i: already known
            fac[i] = mul(fac[i - 1], i);
            ivf[i] = mul(ivf[i - 1], inv[i]);
        }
    }
    	
    // Formal derivative of the n-coefficient series a: b[j] = (j + 1) * a[j + 1]
    // for 0 <= j < n - 1 (b[n-1] and beyond are left untouched).
    inline void wf(int *a, int *b, int n) {
        for(int j = 0; j + 1 < n; j ++) b[j] = mul(a[j + 1], j + 1);
    }
    // Formal integral: b[j + 1] = a[j] / (j + 1) for 0 <= j < n - 1, using the
    // inv table; the constant term b[0] is left untouched.
    inline void jf(int *a, int *b, int n) {
        for(int j = 0; j + 1 < n; j ++) b[j + 1] = mul(a[j], inv[j + 1]);
    }
    	
    // Scratch buffers for In.
    int da[sid], iva[sid];
    // Power-series logarithm: b[0..n-1] = ln(a) mod x^n, computed as the
    // integral of a' * a^{-1}. Requires a[0] == 1. Inv halves n recursively, so
    // n is expected to be a power of two. Overwrites rev, da and iva.
    inline void In(int *a, int *b, int n) {
        for(int i = 0; i < n + n; i ++) da[i] = iva[i] = 0;
        Inv(a, iva, n);  // iva = a^{-1} mod x^n (Inv needs iva[1..] zeroed: done above)
        wf(a, da, n);    // da = a'

        int N = 1, lg = 0;
        init(n + n, N, lg);
        // Inv left rev configured for its own sizes; rebuild it for N.
        for(int i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));

        NTT(da, N, 1); NTT(iva, N, 1);
        for(int i = 0; i < N; i ++) da[i] = mul(da[i], iva[i]);
        NTT(da, N, -1);

        // ln(a) has a zero constant term when a[0] == 1. Write it explicitly
        // rather than relying on the caller having cleared b[0]; jf never
        // touches b[0].
        b[0] = 0;
        jf(da, b, n);
    }
    
    int hb[sid], inb[sid];
    inline void Exp(int *a, int *b, int n) {
    	if(n == 1) { b[0] = 1; return; }
    	Exp(a, b, n >> 1);
    	
    	int N = 1, lg = 0; init(n + n, N, lg);
    	for(ri i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    	
    	for(ri i = 0; i < N; i ++) inb[i] = hb[i] = 0;
    	In(b, inb, n);
    	for(ri i = 0; i < n; i ++) hb[i] = Dec(a[i], inb[i]); hb[0] ++;
    	
    	NTT(inb, N, 1); NTT(hb, N, 1);
    	for(ri i = 0; i < N; i ++) inb[i] = mul(inb[i], hb[i]);
    	NTT(inb, N, -1);
    	
    	for(ri i = 0; i < n; i ++) b[i] = inb[i];
    }
    
    int Ib[sid], F[sid * 2], pa[sid], pb[sid];
    inline void calc(int *a, int *b, int n, int t) {
    	int N = 1, lg = 0;
    	init(max(n, t) + 5, N, lg);
    	for(ri i = 0; i < (N << 1); i ++) F[i] = 0;
    	for(ri i = 0; i < n; i ++) F[2 * i] = 1, F[2 * i + 1] = mod - a[i + 1];
    	for(ri i = n; i < N; i ++) F[2 * i] = 1;
    	for(ri i = 1; i < N; i <<= 1) {
    		for(ri j = 0; j < N; j += (i << 1)) {
    			int M = 1, lg = 0;
    			init((i << 2), M, lg);
    			for(ri k = 0; k < M; k ++) rev[k] = (rev[k >> 1] >> 1) | ((k & 1) << (lg - 1));
    			for(ri k = 0; k < M; k ++) pa[k] = pb[k] = 0;
    			for(ri k = 0; k < (i << 1); k ++) 
    				pa[k] = F[(j << 1) + k], pb[k] = F[(j << 1) + (i << 1) + k];
    			NTT(pa, M, 1); NTT(pb, M, 1);
    			for(ri k = 0; k < M; k ++) pa[k] = mul(pa[k], pb[k]);
    			NTT(pa, M, -1);
    			for(ri k = 0; k < (i << 2); k ++) F[(j << 1) + k] = pa[k];
    		}
    	}
    	for(ri i = 0; i < N; i ++) Ib[i] = 0;
    	In(F, Ib, N); wf(Ib, F, N);
    	
    	b[0] = n;
    	for(ri i = 1; i <= t; i ++) b[i] = mul(mod - F[i - 1], ivf[i]);
    }
    	
    inline void solve(int n, int m, int t) {
    	int N = 1, lg = 0;
    	init(t + t + 5, N, lg);
    	for(ri i = 0; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (lg - 1));
    	
    	NTT(ak, N, 1); NTT(bk, N, 1);
    	for(ri i = 0; i < N; i ++) ak[i] = mul(ak[i], bk[i]);
    	NTT(ak, N, -1);
    		
    	int ivnm = fp(mul(n, m), mod - 2);
    	for(ri i = 1; i <= t; i ++) 
    		printf("%d
    ", mul(mul(ak[i], fac[i]), ivnm));
    }
    	
    // Entry point. Input order: n m, then a[1..n], then b[1..m], then t.
    // Prints the answers for k = 1..t (see solve).
    int main() {
        const int n = read();
        const int m = read();
        for(int i = 1; i <= n; i ++) a[i] = read();
        for(int j = 1; j <= m; j ++) b[j] = read();
        Init_Inv(500000);  // fac / inv / ivf tables used by calc, jf and solve
        const int t = read();
        calc(a, ak, n, t);
        calc(b, bk, m, t);
        solve(n, m, t);
        return 0;
    }
    

    请无视代码中间的 Exp 函数（本题并未调用它）。

  • 相关阅读:
    hdoj 2803 The MAX【简单规律题】
    hdoj 2579 Dating with girls(2)【三重数组标记去重】
    hdoj 1495 非常可乐【bfs隐式图】
    poj 1149 PIGS【最大流经典建图】
    poj 3281 Dining【拆点网络流】
    hdoj 3572 Task Schedule【建立超级源点超级汇点】
    hdoj 1532 Drainage Ditches【最大流模板题】
    poj 1459 Power Network【建立超级源点,超级汇点】
    hdoj 3861 The King’s Problem【强连通缩点建图&&最小路径覆盖】
    hdoj 1012 u Calculate e
  • 原文地址:https://www.cnblogs.com/reverymoon/p/10187388.html
Copyright © 2011-2022 走看看