zoukankan      html  css  js  c++  java
  • 「JSOI2019」神经网络(容斥+组合计数+背包dp)

    Address

    luogu5333

    loj3102

    Solution

    容易发现,一条哈密顿回路本质上就是:把每棵树都拆成若干条有向路径,再把所有的有向路径连接成环,环上的相邻两条有向路径不可以来自同一棵树。

    先求出 \(g_{i,j}\) 表示把第 \(i\) 棵树拆成 \(j\) 条有向路径的方案数。

    考虑 \(\text{dp}\),记 \(f_{u,i,0/1/2/3}\) 分别表示:\(u\) 的子树拆成 \(i\) 条路径,\(u\) 是路径起点,是路径终点,单点成路径,既不是路径起点也不是路径终点的方案数。

    注意 \(f_{u,i,0/1}\) 不允许 \(u\) 单点成路径。转移随便讨论一下即可。最终 \(g_{i,j}=f_{u,i,0}+f_{u,i,1}+f_{u,i,2}+f_{u,i,3}\)

    接下来,假设我们对于所有的 \(i∈[1,m]\),已经确定第 \(i\) 棵树拆成 \(a_i\) 条路径,那么如何计算答案呢?

    考虑容斥。枚举第 \(i\) 棵树的路径在环上被划分为至多 \(b_i\) 段。我们钦定第 \(m\) 棵树的 \(1\) 号节点所在的路径为环上第一条路径,那么在此条件下的方案数为:

    \[(\prod_{i=1}^mC_{a_i-1}^{b_i-1})×(\prod_{i=1}^{m-1}g_{i,a_i}×a_i!)×g_{m,a_m}×(a_m-1)!×\frac{(\sum_{i=1}^{m-1}b_i)!}{\prod_{i=1}^{m-1}b_i!}×C_{(\sum_{i=1}^nb_i)-2}^{b_m-1} \]

    其中 \(\prod_{i=1}^mC_{a_i-1}^{b_i-1}\) 表示把每棵树的路径划分成 \(i\) 段的方案数,\(g_{i,a_i}×a_i!\) 表示在第 \(i\) 棵树上选出 \(a_i\) 条路径形成一个排列的方案数。

    如果确定了选出的路径形成的排列是哪些,划分成的 \(b_i\) 段分别是什么,问题就转化为:第 \(i\) 种颜色的球有 \(b_i\) 个,同种颜色的球之间没有区别,求把所有的求串成环,使得相邻两个异色的方案数。

    钦定环上第一条路径相当于钦定第一个球的颜色一定为 \(m\),最后一个球的颜色不是 \(m\),然后断环为链。所以先把前 \(m-1\) 种颜色的球排好(方案数为),最后在中间的 \((\sum b_i)-2\) 个位置中选出 \(b_m-1\) 个位置放颜色为 \(m\) 的球,剩下的空位给其它的球即可。

    加上容斥系数之后对答案的贡献即

    \[(\prod_{i=1}^mC_{a_i-1}^{b_i-1})×(-1)^{\sum_{i=1}^ma_i-b_i}×(\prod_{i=1}^{m-1}g_{i,a_i}×a_i!)×g_{m,a_m}×(a_m-1)!×\frac{(\sum_{i=1}^{m-1}b_i)!}{\prod_{i=1}^{m-1}b_i!}×C_{(\sum_{i=1}^nb_i)-2}^{b_m-1} \]

    可是直接枚举所有 \(a_i,b_i\) 的复杂度是指数级的,考虑优化。

    记第 \(i\) 棵树的点数为 \(cnt_i\)。对于第 \(i\) 棵树(\(i∈[1,m-1]\)),枚举 \(j=a_i,k=b_i\),可以写出这样的一个多项式:

    \[\sum_{j=1}^{cnt_i}g_{i,j}×i!\sum_{k=1}^jC_{j-1}^{k-1}×(-1)^{j-k}×\frac{x^k}{k!} \]

    然后我们先把这 \(m-1\) 个多项式相乘,再把 \(x^k\) 的系数乘上 \(k!\)。这样得到的 \(x^k\) 的系数 \(A_k\) 相当于:对于 \(i∈[1,m-1]\),枚举所有 \(a_i,b_i\) 满足 \(\sum_{i=1}^{m-1}b_i=k\),然后把 $$(\prod_{i=1}mC_{a_i-1}{b_i-1})×(-1){\sum_{i=1}ma_i-b_i}×(\prod_{i=1}{m-1}g_{i,a_i}×a_i!)×\frac{(\sum_{i=1}{m-1}b_i)!}{\prod_{i=1}^{m-1}b_i!}$$ 计入 \(A_k\)

    接下来写出第 \(m\) 个多项式,记 \(B_k\) 表示这个多项式第 \(k\) 项的系数:

    \[\sum_{j=1}^{cnt_m}g_{m,j}×(j-1)!×\sum_{k=1}^jC_{j-1}^{k-1}×(-1)^{j-k}×x^k \]

    最后枚举 \(k=\sum_{i=1}^{m-1}b_i\),再枚举 \(j=b_m\),把 \(A_k×B_j×C_{k+j-2}^{j-1}\) 计入答案即可。

    时间复杂度 \(O((\sum cnt_i)^2)\)

    Code

    #include <bits/stdc++.h>
    
    using namespace std;
    
    #define ll long long
    
    template <class t>
    inline void read(t & res)
    {
    	char ch;
    	while (ch = getchar(), !isdigit(ch));
    	res = ch ^ 48;
    	while (ch = getchar(), isdigit(ch))
    		res = res * 10 + (ch ^ 48);
    }
    
    const int e = 5005, o = 305, mod = 998244353;
    
    int g[o][e], f[e][e][4], sze[e], m, n, ans, tmp[e][4], cnt, pre[e], h[e], p[o][e];
    int adj[e], nxt[e << 1], go[e << 1], num, tot[o], fac[e], inv[e], a[o][e];
    
    inline void add(int &x, int y)
    {
    	(x += y) >= mod && (x -= mod);
    }
    
    inline int plu(int x, int y)
    {
    	add(x, y);
    	return x;
    }
    
    inline int mul(int x, int y)
    {
    	return (ll)x * y % mod;
    }
    
    inline int ksm(int x, int y)
    {
    	int res = 1;
    	while (y)
    	{
    		if (y & 1) res = mul(res, x);
    		y >>= 1;
    		x = mul(x, x);
    	}
    	return res;
    }
    
    inline void link(int x, int y)
    {
    	nxt[++num] = adj[x]; adj[x] = num; go[num] = y;
    	nxt[++num] = adj[y]; adj[y] = num; go[num] = x;
    }
    
    inline int c(int x, int y)
    {
    	if (x < y) return 0;
    	if (x == y) return 1;
    	return mul(fac[x], mul(inv[y], inv[x - y]));
    }
    
    inline void clear()
    {
    	int i, j, k; num = ans = 0;
    	for (i = 1; i <= n; i++)
    	{
    		adj[i] = 0;
    		for (j = 1; j <= n; j++)
    			for (k = 0; k <= 3; k++)
    				f[i][j][k] = 0;
    	}
    }
    
    inline void dfs(int u, int pa)
    {
    	sze[u] = f[u][1][2] = 1;
    	int i, j, k, l;
    	for (i = adj[u]; i; i = nxt[i])
    	{
    		int v = go[i];
    		if (v == pa) continue;
    		dfs(v, u);
    		for (j = 1; j <= sze[u] + sze[v]; j++)
    			for (k = 0; k <= 3; k++)
    				tmp[j][k] = f[u][j][k], f[u][j][k] = 0;
    		for (j = 1; j <= sze[u]; j++)
    			for (k = 1; k <= sze[v]; k++)
    			{
    				int s = plu(f[v][k][0], f[v][k][2]);
    				add(f[u][j + k - 1][0], mul(tmp[j][2], s));
    				add(f[u][j + k - 1][3], mul(tmp[j][1], s));
    				
    				s = plu(f[v][k][1], f[v][k][2]);
    				add(f[u][j + k - 1][1], mul(tmp[j][2], s));
    				add(f[u][j + k - 1][3], mul(tmp[j][0], s));
    				
    				s = 0;
    				for (l = 0; l <= 3; l++)
    					add(s, f[v][k][l]);
    				for (l = 0; l <= 3; l++)
    					add(f[u][j + k][l], mul(tmp[j][l], s));
    			}
    		sze[u] += sze[v];
    	}
    }
    
    inline void solve(int k)
    {
    	read(n); clear();
    	cnt += n; tot[k] = n; pre[k] = tot[k] + pre[k - 1];
    	int i, x, y, j;
    	for (i = 1; i < n; i++)
    		read(x), read(y), link(x, y);
    	dfs(1, 0);
    	for (i = 1; i <= n; i++)
    		for (j = 0; j <= 3; j++)
    			add(g[k][i], f[1][i][j]);
    }
    
    int main()
    {
    	read(m);
    	int i, j, k;
    	for (i = 1; i <= m; i++)
    		solve(i);
    		
    	fac[0] = 1;
    	for (i = 1; i <= cnt; i++)
    		fac[i] = mul(fac[i - 1], i);
    	inv[cnt] = ksm(fac[cnt], mod - 2);
    	for (i = cnt - 1; i >= 0; i--)
    		inv[i] = mul(inv[i + 1], i + 1);
    		
    	for (i = 1; i < m; i++)
    		for (j = 1; j <= tot[i]; j++)
    			g[i][j] = mul(g[i][j], fac[j]);
    		
    	for (j = 1; j <= tot[m]; j++)
    		g[m][j] = mul(g[m][j], fac[j - 1]);
    		
    	for (i = 1; i < m; i++)
    		for (j = 1; j <= tot[i]; j++)
    			for (k = 1; k <= j; k++)
    			{
    				int v = mul(g[i][j], c(j - 1, k - 1));
    				v = mul(v, inv[k]);
    				if ((j - k) & 1) add(p[i][k], mod - v);
    				else add(p[i][k], v);
    			}
    	
    	for (j = 1; j <= tot[m]; j++)
    		for (k = 1; k <= j; k++)
    		{
    			int v = mul(g[m][j], c(j - 1, k - 1));
    			if ((j - k) & 1) add(p[m][k], mod - v);
    			else add(p[m][k], v);
    		}
    			
    	a[0][0] = 1;
    	for (i = 1; i < m; i++)
    		for (j = 1; j <= pre[i]; j++)
    			for (k = 1; k <= j && k <= tot[i]; k++)
    				add(a[i][j], mul(a[i - 1][j - k], p[i][k]));
    				
    	for (i = 1; i <= pre[m - 1]; i++)
    		for (j = 1; j <= tot[m]; j++)
    		{
    			int x = mul(a[m - 1][i], fac[i]), y = p[m][j];
    			add(ans, mul(x, mul(y, c(i + j - 2, j - 1)))); 
    		}
    	cout << ans << endl;
    	return 0;
    }
    
  • 相关阅读:
    tcp为什么要三次握手
    TCP/IP协议(一)网络基础知识
    拜占庭将军问题深入探讨
    Block Manager
    Standalone 集群部署
    Spark内存管理
    Checkpoint & cache & persist
    Python——在Python中如何使用Linux的epoll
    网络编程——C10K简述
    网络编程——The C10K Problem(C10K = connection 10 kilo 问题)。k 表示 kilo,即 1000
  • 原文地址:https://www.cnblogs.com/cyf32768/p/12196010.html
Copyright © 2011-2022 走看看