zoukankan      html  css  js  c++  java
  • @loj


    @description@

    在一个 s 个点的图中,存在 s - n 条边,使图中形成了 n 个连通块,第 i 个连通块中有 (a_i) 个点。

    现在我们需要再连接 n - 1 条边,使该图变成一棵树。对一种连边方案,设原图中第 i 个连通块连出了 (d_i) 条边,那么这棵树 T 的价值为:

    [val(T) = (prod_{i=1}^{n}d_{i}^{m})(sum_{i=1}^{n}d_{i}^{m}) ]

    你的任务是求出所有可能的生成树的价值之和,对 998244353 取模。

    原题戳我

    @solution@

    @正文@

    注意到 (d_i) 为度数,那么考虑 prufer 序列,直接写出答案表达式:

    [ans = sum_{(sum_{i=1}^{n}b_i)=n-2}(frac{(n-2)!}{prod_{i=1}^{n}b_i!}) imes(prod_{i=1}^{n}a_{i}^{b_i + 1}) imes(prod_{i=1}^{n}(b_{i} + 1)^{m}) imes(sum_{i=1}^{n}(b_{i} + 1)^{m}) ]

    其中 (b_i + 1 = d_i)

    作一些简单的变形:

    [ans = (n-2)! imes(prod_{i=1}^{n}a_i) imessum_{i=1}^{n}sum_{(sum_{j=1}^{n}b_j)=n-2}(frac{(b_{i} + 1)^{2m} imes a_{i}^{b_{i}}}{b_{i}!}) imes(prod_{j=1,j ot =i}^{n}frac{(b_{j} + 1)^{2m} imes a_{j}^{b_{j}}}{b_{j}!}) ]

    引入生成函数。如果记 (P(x) = sum_{i=0}frac{(i + 1)^{2m} imes x^i}{i!})(Q(x) = sum_{i=0}frac{(i + 1)^{m} imes x^i}{i!}),则:

    [ans = (n-2)! imes(prod_{i=1}^{n}a_i) imes([x^{n-2}]sum_{i=1}^{n}P(a_i x) imes(prod_{j=1,j ot =i}^{n}Q(a_j x)))\ ans = (n-2)! imes(prod_{i=1}^{n}a_i) imes([x^{n-2}]prod_{i=1}^{n}Q(a_i x) imessum_{i=1}^{n}frac{P(a_i x)}{Q(a_i x)})]

    注意到 (frac{P(a_i x)}{Q(a_i x)}) 其实就是 (frac{P(x)}{Q(x)}) 的第 k 项乘上 (a_i^{k})

    也就是说 (sum_{i=1}^{n}frac{P(a_i x)}{Q(a_i x)}) 就是 (frac{P(x)}{Q(x)}) 的第 k 项乘上 (sum_{i=1}^{n}a_i^{k}),而 (sum_{i=1}^{n}a_i^{k}) 是可以快速求出的(在补充部分介绍)。

    尝试把 (prod_{i=1}^{n}Q(a_i x)) 也化成加法形式:利用对数,可以得到 (prod_{i=1}^{n}Q(a_i x) = exp(sum_{i=1}^{n}ln(Q(a_i x))))

    之后就没有了。只要求出了 (sum_{i=1}^{n}a_i^{k}),剩下的都是模板。

    @补充@

    关于如何求 (sum_{i=1}^{n}a_i^{k}),其实方法比较多,这里介绍一种:

    注意到 (ln(1 - x) = -sum_{i=1}frac{x^i}{i}),那么只要求出 (sum_{i=1}^{n}ln(1 - a_ix)),也就求出了 (sum_{i=1}^{n}a_i^{k})

    利用对数的性质,有 (sum_{i=1}^{n}ln(1 - a_ix) = ln(prod_{i=1}^{n}(1 - a_ix)))

    然后里面那个式子分治 fft 可以 O(nlog^2n) 搞定,这样一来总时间复杂度其实就是 O(nlog^2n)。

    @accepted code@

    #include <cstdio>
    #include <algorithm>
    using namespace std;
    
    const int MAXN = 4*30000;
    const int MOD = 998244353;
    
    struct mint{
    	int x;
    	mint(int _x=0) : x(_x) {}
    	friend mint operator + (mint a, mint b) {
    		return a.x + b.x >= MOD ? a.x + b.x - MOD : a.x + b.x;
    	}
    	friend mint operator - (mint a, mint b) {
    		return a.x - b.x < 0 ? a.x - b.x + MOD : a.x - b.x;
    	}
    	friend mint operator * (mint a, mint b) {
    		return (int)(1LL * a.x * b.x % MOD);
    	}
    	friend mint pow_mod(mint b, int p) {
    		mint ret = 1;
    		while( p ) {
    			if( p & 1 ) ret *= b;
    			b *= b;
    			p >>= 1;
    		}
    		return ret;
    	}
    	friend mint operator / (mint a, mint b) {
    		return a * pow_mod(b, MOD - 2);
    	}
    	friend void operator += (mint &a, mint b) {a = a + b;}
    	friend void operator -= (mint &a, mint b) {a = a - b;}
    	friend void operator *= (mint &a, mint b) {a = a * b;}
    	friend void operator /= (mint &a, mint b) {a = a / b;}
    };
    
    namespace poly{
    	const mint G = 3;	
    	mint w[20], iw[20], inv[MAXN + 5];
    	void init() {
    		for(int i=0;i<20;i++) {
    			w[i] = pow_mod(G, (MOD-1)/(1<<i));
    			iw[i] = pow_mod(w[i], MOD-2);
    		}
    		inv[1] = 1;
    		for(int i=2;i<=MAXN;i++)
    			inv[i] = 0 - (MOD/i)*inv[MOD%i];
    	}
    	void debug(mint *A, int n) {
    		for(int i=0;i<n;i++)
    			printf("%d ", A[i].x);
    		puts("");
    	}
    	void ntt(mint *A, int n, int type) {
    		for(int i=0,j=0;i<n;i++) {
    			if( i < j ) swap(A[i], A[j]);
    			for(int k=(n>>1);(j^=k)<k;k>>=1);
    		}
    		for(int i=1;(1<<i)<=n;i++) {
    			int s = (1 << i), t = (s >> 1);
    			mint u = (type == 1 ? w[i] : iw[i]);
    			for(int j=0;j<n;j+=s) {
    				mint p = 1;
    				for(int k=0;k<t;k++,p*=u) {
    					mint x = A[j+k], y = A[j+k+t];
    					A[j+k] = x + y*p, A[j+k+t] = x - y*p;
    				}
    			}
    		}
    		if( type == -1 ) {
    			mint iv = inv[n];
    			for(int i=0;i<n;i++)
    				A[i] *= iv;
    		}
    	}
    	int length(int n) {
    		int len; for(len = 1; len < n; len <<= 1);
    		return len;
    	}
    	void pcopy(mint *A, mint *B, int n, int l) {
    		for(int i=0;i<n;i++) A[i] = B[i];
    		for(int i=n;i<l;i++) A[i] = 0;
    	}
    	mint t1[MAXN + 5], t2[MAXN + 5];
    	void pmul(mint *A, int nA, mint *B, int nB, mint *C) {
    		int len = length(nA + nB - 1);
    		pcopy(t1, A, nA, len), ntt(t1, len, 1);
    		pcopy(t2, B, nB, len), ntt(t2, len, 1);
    		for(int i=0;i<len;i++) C[i] = t1[i] * t2[i];
    		ntt(C, len, -1);
    	}
    	mint t3[MAXN + 5], t4[MAXN + 5];
    	void pinv(mint *A, mint *B, int n) {
    		if( n == 1 ) {
    			B[0] = 1 / A[0];
    			return ;
    		}
    		int m = (n + 1) >> 1; pinv(A, B, m);
    		int len = length(n << 1);
    		pcopy(t3, A, n, len), ntt(t3, len, 1);
    		pcopy(t4, B, m, len), ntt(t4, len, 1);
    		for(int i=0;i<len;i++) B[i] = t4[i]*(2 - t3[i]*t4[i]);
    		ntt(B, len, -1);
    	}
    	void pdif(mint *A, mint *B, int n) {
    		for(int i=1;i<n;i++)
    			B[i-1] = A[i] * i;
    	}
    	void pint(mint *A, mint *B, int n) {
    		for(int i=n-1;i>=0;i--)
    			B[i+1] = A[i] / (i + 1);
    		B[0] = 0;
    	}
    	mint t5[MAXN + 5], t6[MAXN + 5];
    	void pln(mint *A, mint *B, int n) {
    		pinv(A, t5, n), pdif(A, t6, n);
    		pmul(t5, n, t6, n, B);
    		pint(B, B, n);
    	}
    	mint t7[MAXN + 5], t8[MAXN + 5];
    	void pexp(mint *A, mint *B, int n) {
    		if( n == 1 ) {
    			B[0] = 1;
    			return ;
    		}
    		int m = (n + 1) >> 1; pexp(A, B, m);
    		int len = length(n << 1);
    		pcopy(t7, B, m, len), pln(t7, t8, n), pcopy(t7, t8, n, len);
    		pcopy(t8, B, m, len);
    		for(int i=0;i<n;i++) t7[i] = A[i] - t7[i];
    		t7[0] = t7[0] + 1;
    		ntt(t7, len, 1), ntt(t8, len, 1);
    		for(int i=0;i<len;i++) B[i] = t7[i] * t8[i];
    		ntt(B, len, -1);
    	}
    }
    
    int n, m, k;
    
    mint A[MAXN + 5], B[MAXN + 5];
    void init() {
    	mint t = 1;
    	for(int i=0;i<n;i++,t*=i) {
    		mint a = 1 / t, b = pow_mod(mint(i + 1), m);
    		A[i] = a * b * b, B[i] = a * b;
    	}
    	poly::init();
    }
    
    mint a[MAXN + 5], f[MAXN + 5], s[MAXN + 5];
    int solve(int l, int r) {
    	if( l == r ) {
    		f[l<<1] = 1, f[l<<1|1] = 0 - a[l];
    		return 2;
    	}
    	int mid = (l + r) >> 1;
    	int ls = solve(l, mid), rs = solve(mid + 1, r);
    	poly::pmul(f + (l<<1), ls, f + ((mid + 1) << 1), rs, f + (l << 1));
    	return ls + rs - 1;
    }
    void get_pow_sum() {
    	solve(0, n - 1), poly::pln(f, s, n + 1);
    	s[0] = n;
    	for(int i=1;i<=n;i++)
    		s[i] = 0 - s[i]*i;
    }
    
    mint t1[MAXN + 5], t2[MAXN + 5];
    int main() {
    	scanf("%d%d", &n, &m), k = n - 2, init();
    	for(int i=0;i<n;i++) scanf("%d", &a[i].x);
    	
    	get_pow_sum();
    	poly::pln(B, t1, n);
    	for(int i=0;i<n;i++)
    		t1[i] *= s[i];
    	poly::pexp(t1, t2, n);
    	poly::pinv(B, t1, n);
    	poly::pmul(A, n, t1, n, t1);
    	for(int i=0;i<n;i++)
    		t1[i] *= s[i];
    	poly::pmul(t1, n, t2, n, t1);
    	mint ans = t1[n - 2];
    	for(int i=0;i<n;i++) ans *= a[i];
    	for(int i=1;i<=n-2;i++) ans *= i;
    	printf("%d
    ", ans.x);
    }
    

    @details@

    顺带一提,这道题还有依赖于斯特林数的 O(nmlogn) 的做法(但是我看不懂 QaQ)。

  • 相关阅读:
    2019年度SAP项目实践计划
    实现祖国统一其实并不难
    2018年终总结之摄影作品展
    2018年终总结之访问量较大的原创文章
    2018年终总结之AI领域开源框架汇总
    2018 AI产业界大盘点
    为什么我觉得Python烂的要死?
    SAP MM 根据采购订单反查采购申请?
    2018-8-10-win10-uwp-ApplicationView
    2018-8-10-WPF-播放-gif
  • 原文地址:https://www.cnblogs.com/Tiw-Air-OAO/p/12119237.html
Copyright © 2011-2022 走看看