题目链接:Garden of Eden
题意:给定一颗n个节点的树,每个节点有一种颜色,颜色有k种,求树上有多少条路径包含这k种颜色,n<=50000,k<=10
思路:树上路径问题,用点分治求解,又由于k<=10,所以可以用二进制状态表示一条路径上包含的颜色集合,比如状态8转换成二进制为1000,那么就表示状态8表示的路径上含有第3种颜色(颜色标号从0开始)
那么考虑过重心向下的某一条路径,假设这条路径的二进制状态为d,设s表示含有所有颜色的集合,即s=(1<<k)-1,那么我们只需要找到二进制状态为s^d的集合的超集与d配对即可,所以需要求超集的和
// 超集和 for (int j = 0; j < k; j++) for (int i = s; i >= 0; i--) if (((1 << j) & i) == 0) tot[i] += tot[i | (1 << j)]; // 子集和 for (int j = 0; j < k; j++) for (int i = 0; i <= s; i++) if ((1 << j) & i) tot[i] ++ tot[i ^ (1 << j)];
求出超集后之后,按照点分治的一般步骤求解即可
#include <iostream> #include <algorithm> #include <cstring> #include <cstdio> using namespace std; typedef long long ll; const int N = 50010; const int M = 1050; struct node { int to, nex; }; node edge[2 * N]; int n, k, cnt, rt, sum, s, c, val[N], d[N]; int head[N], sz[N], son[N], vis[N], tot[M]; ll res; inline void add_edge(int u, int v) { edge[++cnt].to = v; edge[cnt].nex = head[u]; head[u] = cnt; } void dfs(int u, int fa) { sz[u] = 1; son[u] = 0; for (int i = head[u]; 0 != i; i = edge[i].nex) { int v = edge[i].to; if (v == fa || vis[v]) continue; dfs(v, u); sz[u] += sz[v]; son[u] = max(son[u], sz[v]); } son[u] = max(son[u], sum - sz[u]); if (son[u] < son[rt]) rt = u; } void init() { memset(vis, 0, sizeof(vis)); memset(head, 0, sizeof(head)); cnt = 0; res = 0; s = (1 << k) - 1; } void deep(int u, int fa, int now) { d[++c] = now; tot[now]++; for (int i = head[u]; 0 != i; i = edge[i].nex) { int v = edge[i].to; if (vis[v] || v == fa) continue; deep(v, u, now | val[v]); } } ll calc(int u, int now) { c = 0; memset(tot, 0, sizeof(tot)); deep(u, 0, now); for (int j = 0; j < k; j++) for (int i = s; i >= 0; i--) if (((1 << j) & i) == 0) tot[i] += tot[i | (1 << j)]; ll r = 0; for (int i = 1; i <= c; i++) r += tot[d[i] ^ s]; return r; } void solve(int u) { res += calc(u, val[u]); vis[u] = 1; for (int i = head[u]; 0 != i; i = edge[i].nex) { int v = edge[i].to; if (vis[v]) continue; res -= calc(v, val[u] | val[v]); sum = sz[v]; rt = 0; dfs(v, -1); solve(rt); } } int main() { while (scanf("%d%d", &n, &k) != EOF) { init(); for (int i = 1; i <= n; i++) { int a; scanf("%d", &a); val[i] = 1 << (a - 1); } for (int i = 1; i <= n - 1; i++) { int u, v; scanf("%d%d", &u, &v); add_edge(u, v); add_edge(v, u); } rt = 0; sum = son[0] = n; dfs(1, -1); solve(rt); printf("%lld ", res); } return 0; }