分析一下题意,我们发现最终局面中同种颜色一定全部处于同一联通块中。
把点向它对应的颜色连边,对于每一种颜色,求出它的虚树,颜色对虚树中的点连边。
连边 (i o j) 表示 (j) 颜色需要并入 (i) 颜色中。
然后跑一遍 ( ext{Tarjan}) ,此时我们发现只有初度为 (0) 的强连通分量可以直接用来合并,维护每一个强连通分量中表示颜色的点的个数(权值),答案就是所有出度为 (0) 的强连通分量的权值的最小值减一(都向同一个颜色合并)。
参考代码:
#include <algorithm>
#include <vector>
#include <cstdio>
using namespace std;
template < class T > void read(T& s) {
s = 0; int f = 0; char c = getchar();
while ('0' > c || c > '9') f |= c == '-', c = getchar();
while ('0' <= c && c <= '9') s = s * 10 + c - 48, c = getchar();
s = f ? -s : s;
}
const int _ = 4e6 + 5;
int tot, head1[_], head2[_]; struct Edge { int v, nxt; } edge[_ << 2];
void Add_edge(int* head, int u, int v) { edge[++tot] = (Edge) { v, head[u] }, head[u] = tot; }
int n, k, c[_], lg[_]; vector < int > vec[_];
int pos[_], dep[_], fa[20][_], cnt, st[20][_];
int col, co[_], top, stk[_], num, dfn[_], low[_];
int val[_], dgr[_]; vector < int > scc[_];
void dfs(int u, int f) {
dep[u] = dep[f] + 1, pos[u] = ++pos[0];
for (int i = 1; i <= 18; ++i) {
fa[i][u] = fa[i - 1][fa[i - 1][u]];
if (fa[i][u] != 0) {
st[i][u] = ++cnt;
Add_edge(head2, st[i][u], st[i - 1][u]);
Add_edge(head2, st[i][u], st[i - 1][fa[i - 1][u]]);
}
}
for (int i = head1[u]; i; i = edge[i].nxt) {
int v = edge[i].v; if (v == f) continue ;
fa[0][v] = u, st[0][v] = ++cnt;
Add_edge(head2, st[0][v], v);
Add_edge(head2, st[0][v], u);
dfs(v, u);
}
}
int LCA(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
for (int i = 18; ~i; --i)
if (dep[fa[i][x]] >= dep[y]) x = fa[i][x];
if (x == y) return x;
for (int i = 18; ~i; --i)
if (fa[i][x] != fa[i][y]) x = fa[i][x], y = fa[i][y];
return fa[0][x];
}
void update(int u, int len) {
if (len == 0) return ;
int x = lg[len], tmp = c[u] + n;
Add_edge(head2, tmp, st[x][u]);
if ((len -= 1 << x) == 0) return ;
for (int i = 0; (1 << i) <= len; ++i) if ((len >> i) & 1) u = fa[i][u];
Add_edge(head2, tmp, st[x][u]);
}
void tarjan(int u) {
dfn[u] = low[u] = ++num, stk[++top] = u;
for (int i = head2[u]; i; i = edge[i].nxt) {
int v = edge[i].v;
if (!dfn[v])
tarjan(v), low[u] = min(low[u], low[v]);
else
if (!co[v]) low[u] = min(low[u], dfn[v]);
}
if (low[u] == dfn[u]) {
++col;
do scc[co[stk[top]] = col].push_back(stk[top]);
while (stk[top--] != u);
}
}
int main() {
#ifndef ONLINE_JUDGE
freopen("cpp.in", "r", stdin), freopen("cpp.out", "w", stdout);
#endif
read(n), read(k), cnt = n + k;
for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
for (int u, v, i = 1; i < n; ++i)
read(u), read(v), Add_edge(head1, u, v), Add_edge(head1, v, u);
for (int i = 1; i <= n; ++i) read(c[i]), vec[c[i]].push_back(i), Add_edge(head2, i, c[i] + n);
dfs(1, 0);
for (int i = 1; i <= k; ++i) {
int l = vec[i][0], r = vec[i][0];
for (int j = 1; j < vec[i].size(); ++j) {
if (pos[vec[i][j]] < pos[l]) l = vec[i][j];
if (pos[vec[i][j]] > pos[r]) r = vec[i][j];
}
int lca = LCA(l, r);
for (int j = 0; j < vec[i].size(); ++j)
update(vec[i][j], dep[vec[i][j]] - dep[lca]);
}
for (int i = 1; i <= cnt; ++i) if (!dfn[i]) tarjan(i);
for (int i = 1; i <= col; ++i)
for (int j = 0; j < scc[i].size(); ++j) if (scc[i][j] > n && scc[i][j] <= n + k) ++val[i];
int ans = n;
for (int i = 1; i <= col; ++i)
for (int j = 0; j < scc[i].size(); ++j)
for (int k = head2[scc[i][j]]; k; k = edge[k].nxt)
if (co[edge[k].v] != i) ++dgr[i];
for (int i = 1; i <= col; ++i) if (dgr[i] == 0) ans = min(ans, val[i] - 1);
printf("%d
", ans);
return 0;
}