P3349 [ZJOI2016]小星星
声明:本博客所有题解都参照了网络资料或其他博客,仅为博主想加深理解而写,如有疑问欢迎与博主讨论✧。٩(ˊᗜˋ)و✧*。
题目描述
小 (Y) 是一个心灵手巧的女孩子,她喜欢手工制作一些小饰品。她有 (n) 颗小星星,用 (m) 条彩色的细线串了起来,每条细线连着两颗小星星。
有一天她发现,她的饰品被破坏了,很多细线都被拆掉了。这个饰品只剩下了 (n-1) 条细线,但通过这些细线,这颗小星星还是被串在一起,也就是这些小星星通过这些细线形成了树。小 (Y) 找到了这个饰品的设计图纸,她想知道现在饰品中的小星星对应着原来图纸上的哪些小星星。如果现在饰品中两颗小星星有细线相连,那么要求对应的小星星原来的图纸上也有细线相连。小 (Y) 想知道有多少种可能的对应方式。
只有你告诉了她正确的答案,她才会把小饰品做为礼物送给你呢。
输入格式
第一行包含个 (2) 正整数 (n,m),表示原来的饰品中小星星的个数和细线的条数。接下来 (m) 行,每行包含 (2) 个正整数 (u,v),表示原来的饰品中小星星 (u) 和 (v) 通过细线连了起来。这里的小星星从 (1) 开始标号。保证 (u≠v),且每对小星星之间最多只有一条细线相连。接下来 (n-1) 行,每行包含个 (2) 正整数 (u,v),表示现在的饰品中小星星 (u) 和 (v) 通过细线连了起来。保证这些小星星通过细线可以串在一起。(n<=17,m<=n*(n-1)/2)
输出格式
输出共 (1) 行,包含一个整数表示可能的对应方式的数量。如果不存在可行的对应方式则输出 (0)。
(Solution)
从一个暴力开始考虑,设 (f[i][j][S]) 表示,以 (i) 为树根的子树,当 (i) 映射到 (j) 上,映射的子集为 (S) 时的方案数。
不断合并 (i) 的子树来转移,复杂度貌似是 (O(3^n n)) 的,但是不知道为什么 (emm)
考虑如何优化
如果我们去除 (S) 这一维,那么答案中会有一些不合法的情况,但这些不合法的情况仅为多个点映射到同一个点,不会出现两个点之间没有边相连的情况
于是我们可以做容斥,去除掉这些不合法情况
首先枚举删除一个点,在剩下的图上做一遍 (f[i][j]) 的 (dp),此时找到的状态至少有两个点映射在了同一个点上(因为图上只有 (n - 1) 个点,而树上有 (n) 个)
我们把这些状态都删除,但是会发现,会重复删除掉三个点,四个点等映射在同一个点上的情况,于是再加上它们
然后就可以做容斥了
(Code)
代码其实挺简单的,但是要注意在做 (dp) 的时候,把儿子循环放外面,先做儿子的 (dp) 再枚举 (j),不然会超时(不过这个逻辑应该挺简单,我真是降智了)
#include<bits/stdc++.h>
#define ll long long
#define F(i, x, y) for(int i = x; i <= y; ++i)
using namespace std;
int read();
const int N = 400;
int n, m;
int mp[N][N];
int head[N], cnt, ver[N], next[N];
ll ban[N], f[N][N];
ll ans;
void add(int x, int y)
{
ver[++ cnt] = y, next[cnt] = head[x], head[x] = cnt;
}
void dfs(int x, int fa)
{
F(j, 1, n) f[x][j] = 1;
for(int i = head[x]; i; i = next[i])
if(ver[i] != fa)
{
dfs(ver[i], x);
F(j, 1, n)
if(! ban[j])
{
ll res = 0;
F(v, 1, n) if(! ban[v]) res += f[ver[i]][v] * mp[v][j];
f[x][j] *= res;
}
}
}
int main()
{
n = read(), m = read();
for(int i = 1, u, v; i <= m; ++ i) u = read(), v = read(), mp[u][v] = mp[v][u] = 1;
for(int i = 1, u, v; i <= n - 1; ++ i) u = read(), v = read(), add(u, v), add(v, u);
F(k, 0, (1 << n) - 1) // 可以直接枚举状态,不用按顺序,这样简单很多
{
F(i, 1, n) ban[i] = 0;
F(i, 1, n) F(j, 1, n) f[i][j] = 0;
int size = 0; ll res = 0;
for(int i = 1, s = k; s; ++ i, s >>= 1) if(s & 1) ban[i] = 1, ++ size;
dfs(1, 0);
F(i, 1, n) res += f[1][i];
if(size & 1) ans -= res;
else ans += res;
}
printf("%lld
", ans);
return 0;
}
int read()
{
int x = 0, f = 1;
char c = getchar();
while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}