zoukankan      html  css  js  c++  java
  • @loj


    @description@

    本题包含三个问题:

    问题 0:已知两棵 n 个结点的树的形态(两棵树的结点标号均为 1~n),其中第一棵树是红树,第二棵树是蓝树。要给予每个结点一个 [1, y] 中的整数,使得对于任意两个节点 p, q,如果存在一条路径 (a1 = p, a2, ..., am = q) 同时属于这两棵树,则 p, q 必须被给予相同的数。求给予数的方案数。

    问题 1:已知蓝树,对于红树的所有 (n^{n-2}) 种选择方案,求问题 0 的答案之和。

    问题 2:对于蓝树的所有 (n^{n-2}) 种选择方案,求问题 1 的答案之和。

    原题请戳我查看qwq

    @solution@

    说点人话,若两棵树边集的交集为 S,则答案等于 (y^{n - |S|})

    前排提醒:下面可能会出现类似 (1 - y) 作分母的情况,当 y = 1 时没有意义。所以需要优先特判掉。
    注意 y = 1 时 |S| 并不会影响,所以只取决于有多少种可能的情况。

    @问题 0@

    相信大家都会做。

    @问题 1@

    不难想到一个指数级的思路:枚举交集 S,记 f(S) 表示满足要求的树的个数。

    交集恰好为 S 显然不好做,而且看起来很好容斥。我们枚举 T,计算交集包含 T 的情况,记为 g(T)。
    稍微思考一下得到容斥式子 (f(S) = sum_{Ssubseteq T}(-1)^{|T|-|S|}g(T))

    则最终答案有如下式子:

    [ans = sum_{S}f(S) imes y^{n - |S|}\ = sum_{S}sum_{Ssubseteq T}(-1)^{|T| - |S|}g(T) imes y^{n - |S|}]

    尝试消去 S:

    [ans = y^nsum_{T}g(T)sum_{Ssubseteq T}(-1)^{|T|-|S|}y^{-|S|} \ = y^nsum_{T}g(T)sum_{i=0}^{|T|}C_{|T|}^{i}(-1)^{|T|-i}y^{-i}]

    用一个二项式定理就可以得到 (ans = y^nsum_{T}g(T)(y^{-1} - 1)^{|T|})
    不妨先记 (u = (y^{-1} - 1)),则 (ans = y^nsum_{T}g(T)u^{|T|})

    尽管如此还是一个指数级算法。考虑 g(T) 应该怎么求,然后优化成多项式算法。
    如果给定边集 T,只要另一棵树中包含 T 中这些边即可。因此相当于先用 T 中的边将 1~n 的点连成 k 个大小为 a1, a2, ..., ak 的连通块,然后再连成一棵树的方案数。
    用 matrix-tree / prufer 可以证明这个方案数为 (g(T) = n^{k-2} imesprod_{i=1}^{k}a_i)(证明详见下面的补充部分)。

    由于 T 中的边连成的连通块个数 (k = n-|T|),所以将原式进一步改写为:

    [ans = y^nsum_{T}(n^{k-2} imesprod_{i=1}^{k}a_i imes u^{n-k}) \ = frac{y^n imes u^n}{n^2} imessum_{T}(prod_{i=1}^{k}(a_i imes n imes u^{-1}))]

    可以作 O(n^2) 的树形 dp:记 dp[i][j] 表示以 i 为根的子树被分成了若干连通块,其中 i 所在的连通块大小为 j,其他连通块的总贡献为 dp[i][j]。

    当然可以更简单:考虑 (a_i imes n imes u^{-1}) 的组合意义。即大小为 (a_i) 的连通块中选择一个,贡献 (n imes u^{-1})
    然后记 dp[0/1][i] 表示 i 所在的连通块是否有点贡献了 (n imes u^{-1}),这样子就是 O(n) 的树形 dp 了。

    @问题 2@

    如果你像我一开始一样,从上面的 dp[0/1][i] 入手,最后就会陷入两个生成函数互相卷积的怪圈中,只能分治 fft O(nlog^2n) 求解。。。

    考虑依然是容斥,其它过程都与上面一样,只是 g(T) 的计算式子变为 (g(T) = (n^{k-2} imesprod_{i=1}^{k}a_i)^2)(因为要枚举两棵树嘛)。

    那么最终答案 (h[n] = frac{y^n imes u^n}{n^4} imessum_{T}(prod_{i=1}^{k}(a_i^2 imes n^2 imes u^{-1})))

    现在枚举 T 反而不好办了。我们考虑直接枚举序列 a,算出有多少边集 T。不妨令点 1 所在的连通块大小为 a1,枚举与点 1 在同一连通块的点得到 h 的转移:

    [h[n] = sum_{i=0}^{n-1}C_{n-1}^{i} imes (i+1)^2 imes n^2 imes u^{-1} imes (i+1)^{i-1} imes h[n-i-1] ]

    上面那个可以直接 O(n^2) 做了。不过还可以进一步优化:
    (p[i] = (i+1)^2 imes n^2 imes u^{-1} imes (i+1)^{i-1}),则上面的卷积又可以写作 (h[n+1] = sum_{i=0}^{n}C_{n}^{i} imes p[i] imes h[n-i])

    这是一个经典的卷积式子,可以写成指数型生成函数然后求多项式 exp(具体可见下面的补充部分)。
    时间复杂度 O(nlogn)。

    @补充部分@

    对上面所提到的两个问题的细节补充。

    (1)1~n 的点连成 k 个大小为 a1, a2, ..., ak 的连通块,然后再连成一棵树的方案数为 (n^{k-2} imesprod_{i=1}^{k}a_i)
    证明我选择的是 prufer 序列(懒得写matrix-tree的矩阵证法,网上应该找得到)

    由于一个数在 prufer 序列中的出现次数为它的度数减一,又因为从某个大小为 ai 的连通块连出去一条边有 ai 种选择,所以有:

    [ans = sum_{sum_{i=1}^{k}d_i = 2k-2}frac{(k-2)!}{prod_{i=1}^{k}(d_i-1)!}prod_{i=1}^{k}a_i^{d_i}\ =prod_{i=1}^{k}a_i imes sum_{sum_{i=1}^{k}(d_i-1) = k-2}frac{(k-2)!}{prod_{i=1}^{k}(d_i-1)!}prod_{i=1}^{k}a_i^{d_i-1}\ =prod_{i=1}^{k}a_i imes (sum_{i=1}^{k}a_i)^{k-2} = prod_{i=1}^{k}a_i imes n^{k-2}]

    关于后面那个怎么来的,其实是逆用多项式的展开:

    [(x_1 + x_2 + dots + x_n)^k = sum_{sum_{i=1}^{n}a_i = k}(frac{k!}{prod_{i=1}^{n}a_i!}prod_{i=1}^{n}x_i^{a_i}) ]

    (2)关于指数型生成函数的 exp 对应的卷积意义。
    首先要认识到,对于指数型生成函数而言,积分相等于右移,求导相当于左移。
    假如令 (f(x) = sum_{i=0}frac{a_{i}}{i!}x^i),则 (f'(x) = sum_{i=0}frac{a_{i+1}}{i!}x^i)(int f(x) = sum_{i=1}frac{a_{i-1}}{i!}x^i)

    根据求导法则,有 (ln(f(x))' = frac{f'(x)}{f(x)}),即 (ln(f(x))' imes f(x) = f'(x))

    如果记 (g(x) = ln(f(x))' = sum_{i=0}frac{b_{i}}{i!}x^i),比较第 n 项等式两边的系数,可以得到:

    [sum_{i=0}^{n}frac{a_i}{i!} imesfrac{b_{n-i}}{(n-i)!} = frac{a_{n+1}}{n!} ]

    然后可以推出 (a_{n+1} = sum_{i=0}^{n}C_n^i imes a_i imes b_{n-i}),就是我们题目中的卷积式子。

    @accepted code@

    #include <set>
    #include <cstdio>
    #include <iostream>
    #include <algorithm>
    using namespace std;
    
    const int MOD = 998244353;
    const int MAXN = 400000;
    
    struct mint{
    	int x;
    	mint(int _x = 0) : x(_x) {}
    	friend mint operator + (mint a, const mint &b) {return (a.x + b.x) % MOD;}
    	friend mint operator - (mint a, const mint &b) {return (a.x + MOD - b.x) % MOD;}
    	friend mint operator * (mint a, const mint &b) {return 1LL * a.x * b.x % MOD;}
    	friend void operator += (mint &a, const mint &b) {a = a + b;}
    	friend void operator -= (mint &a, const mint &b) {a = a - b;}
    	friend void operator *= (mint &a, const mint &b) {a = a * b;}
    	friend mint mpow(mint b, int p) {
    		if( b.x == 1 ) return 1;
    		mint ret = 1;
    		while( p ) {
    			if( p & 1 ) ret = ret * b;
    			b = b * b;
    			p >>= 1;
    		}
    		return ret;
    	}
    	friend mint operator / (mint a, const mint &b) {return a * mpow(b, MOD - 2);}
    	friend void operator /= (mint &a, const mint &b) {a = a / b;}
    };
    
    int n, y, op;
    
    void solve0() {
    	if( op == 0 ) printf("%d
    ", 1);
    	else if( op == 1 ) printf("%d
    ", mpow((mint)n, n - 2).x);
    	else printf("%d
    ", mpow((mint)n, 2*(n - 2)).x);
    }
    
    set<pair<int, int> >e;
    void solve1() {
    	int ans = 0;
    	for(int i=1;i<n;i++) {
    		int u, v; scanf("%d%d", &u, &v);
    		if( u > v ) swap(u, v);
    		e.insert(make_pair(u, v));
    	}
    	for(int i=1;i<n;i++) {
    		int u, v; scanf("%d%d", &u, &v);
    		if( u > v ) swap(u, v);
    		if( e.count(make_pair(u, v)) ) ans++;
    	}
    	printf("%d
    ", mpow((mint)y, n - ans).x);
    }
    
    struct edge{
    	edge *nxt; int to;
    }edges[2*MAXN + 5], *adj[MAXN + 5], *ecnt = edges;
    void addedge(int u, int v) {
    	edge *p = (++ecnt);
    	p->to = v, p->nxt = adj[u], adj[u] = p;
    	p = (++ecnt);
    	p->to = u, p->nxt = adj[v], adj[v] = p;
    }
    mint dp[2][MAXN + 5], del;
    void dfs(int x, int f) {
    	dp[0][x] = 1, dp[1][x] = del;
    	for(edge *p=adj[x];p;p=p->nxt) {
    		if( p->to == f ) continue;
    		dfs(p->to, x);
    		dp[1][x] = dp[1][x] * dp[1][p->to] + dp[1][x] * dp[0][p->to] + dp[0][x] * dp[1][p->to];
    		dp[0][x] = dp[0][x] * dp[1][p->to] + dp[0][x] * dp[0][p->to];
    	}
    }
    void solve2() {
    	for(int i=1;i<n;i++) {
    		int u, v; scanf("%d%d", &u, &v);
    		addedge(u, v);
    	}
    	mint u = 1; u = (u - y) / y;
    	mint p = mpow(y * u, n) / n / n;
    	del = n / u, dfs(1, 0);
    	printf("%d
    ", (dp[1][1] * p).x);
    }
    
    namespace poly{
    	const mint G = 3;
    	mint w[20], iw[20], inv[MAXN + 5];
    	void init() {
    		inv[1] = 1;
    		for(int i=2;i<=MAXN;i++)
    			inv[i] = MOD - inv[MOD%i]*(MOD/i);
    		for(int i=0;i<20;i++)
    			w[i] = mpow(G, (MOD-1)/(1<<i)), iw[i] = 1 / w[i];
    	}
    	void ntt(mint *A, int n, int type) {
    		for(int i=0,j=0;i<n;i++) {
    			if( i < j ) swap(A[i], A[j]);
    			for(int k=(n>>1);(j^=k)<k;k>>=1);
    		}
    		for(int i=1;(1<<i)<=n;i++) {
    			int s = (1 << i), t = (s >> 1);
    			mint u = (type == 1 ? w[i] : iw[i]);
    			for(int j=0;j<n;j+=s) {
    				mint p = 1;
    				for(int k=0;k<t;k++,p*=u) {
    					mint x = A[j + k], y = A[j + k + t];
    					A[j + k] = x + y*p, A[j + k + t] = x - y*p;
    				}
    			}
    		}
    		if( type == -1 ) {
    			for(int i=0;i<n;i++)
    				A[i] *= inv[n];
    		}
    	}
    	mint t1[MAXN + 5], t2[MAXN + 5];
    	int length(int n) {
    		int l; for(l = 1; l < n; l <<= 1);
    		return l;
    	}
    	void mul(mint *A, int nA, mint *B, int nB, mint *C) {
    		int nC = (nA + nB - 1), len = length(nC);
    		for(int i=0;i<nA;i++) t1[i] = A[i];
    		for(int i=nA;i<len;i++) t1[i] = 0;
    		for(int i=0;i<nB;i++) t2[i] = B[i];
    		for(int i=nB;i<len;i++) t2[i] = 0;
    		ntt(t1, len, 1), ntt(t2, len, 1);
    		for(int i=0;i<len;i++) C[i] = t1[i] * t2[i];
    		ntt(C, len, -1);
    	}
    	mint t3[MAXN + 5], t4[MAXN + 5];
    	void pinv(mint *A, mint *B, int n) {
    		if( n == 1 ) {
    			B[0] = 1 / A[0];
    			return ;
    		}
    		int m = (n + 1) >> 1;
    		pinv(A, B, m);
    		int len = length(n << 1);
    		for(int i=0;i<m;i++) t3[i] = B[i];
    		for(int i=m;i<len;i++) t3[i] = 0;
    		for(int i=0;i<n;i++) t4[i] = A[i];
    		for(int i=n;i<len;i++) t4[i] = 0;
    		ntt(t3, len, 1), ntt(t4, len, 1);
    		for(int i=0;i<len;i++)
    			B[i] = t3[i] * (2 - t3[i] * t4[i]);
    		ntt(B, len, -1);
    	}
    	void pdif(mint *A, mint *B, int n) {
    		for(int i=1;i<n;i++)
    			B[i-1] = A[i] * i;
    	}
    	void pint(mint *A, mint *B, int n) {
    		for(int i=n-1;i>=0;i--)
    			B[i+1] = A[i] * inv[i + 1];
    		B[0] = 0;
    	}
    	mint t5[MAXN + 5], t6[MAXN + 5];
    	void ln(mint *A, mint *B, int n) {
    		pdif(A, t5, n), pinv(A, t6, n);
    		mul(t5, n - 1, t6, n, t5);
    		pint(t5, B, n);
    	}
    	mint t7[MAXN + 5], t8[MAXN + 5];
    	void exp(mint *A, mint *B, int n) {
    		if( n == 1 ) {
    			B[0] = 1;
    			return ;
    		}
    		int m = (n + 1) >> 1;
    		exp(A, B, m);
    		for(int i=0;i<m;i++) t7[i] = B[i];
    		for(int i=m;i<n;i++) t7[i] = 0;
    		ln(t7, t8, n);
    		for(int i=0;i<n;i++) t7[i] = A[i] - t8[i];
    		t7[0].x += 1;
    		for(int i=0;i<m;i++) t8[i] = B[i];
    		mul(t7, n, t8, m, B);
    	}
    }
    
    mint fct[MAXN + 5], ifct[MAXN + 5];
    void init() {
    	poly::init(); fct[0] = 1;
    	for(int i=1;i<=MAXN;i++) fct[i] = fct[i-1] * i;
    	ifct[MAXN] = 1 / fct[MAXN];
    	for(int i=MAXN-1;i>=0;i--) ifct[i] = ifct[i+1] * (i+1);
    }
    /*
    mint comb(int n, int m) {
    	return fct[n] * ifct[m] * ifct[n-m];
    }
    */
    mint f[MAXN + 5], g[MAXN + 5];
    void solve3() {
    	init();
    	mint u = 1; u = (u - y) / y;
    	mint p = mpow(y * u, n) / (mint(n) * n * n * n);
    	del = n / u * n;
    /*
    	for(int i=0;i<n;i++)
    		g[i] = mpow(mint(i+1), i-1) * del * mint(i+1) * mint(i+1);
    	f[0] = 1;
    	for(int i=0;i<n;i++)
    		for(int j=0;j<=i;j++)
    			f[i+1] += comb(i, j)*g[i-j]*f[j];
    	printf("%d
    ", (f[n] * p).x);
    */
    	for(int i=0;i<n;i++)
    		g[i] = mpow(mint(i+1), i-1) * del * mint(i+1) * mint(i+1), g[i] *= ifct[i];
    	poly::pint(g, g, n);
    	poly::exp(g, f, n + 1);
    	printf("%d
    ", (f[n] * p * fct[n]).x);
    }
    
    int main() {
    	scanf("%d%d%d", &n, &y, &op);
    	if( y == 1 ) solve0();
    	else if( op == 0 ) solve1();
    	else if( op == 1 ) solve2();
    	else if( op == 2 ) solve3();
    }
    

    @details@

    讲道理,这道题并不算太难分析。

    不过可以学到很多分析组合计数的知识与技巧。

  • 相关阅读:
    【Java】推断文件的后缀名
    UVa 131
    Java开发手冊 Java学习手冊教程(MtJava开发手冊)
    《Java并发编程实战》第十五章 原子变量与非堵塞同步机制 读书笔记
    OC语言--NSFileManager&amp; NSFileHandle
    为什么我刚发表的文章变成了“待审核”,csdn有没有官方解释啊
    mac os升级为 Yosemite 10.10 后不能创建javaproject
    【spring】在spring cloud项目中使用@ControllerAdvice做自定义异常拦截,无效 解决原因
    【mybatis】mybatis动态order by 的问题, 注意 只需要把#{} 改成 ${} 即可
    【spring cloud】一个ms微服务想要给注册中心eureka发现,需要满足这些条件,微服务不能被eureka注册中心发现的解决方案
  • 原文地址:https://www.cnblogs.com/Tiw-Air-OAO/p/12092766.html
Copyright © 2011-2022 走看看