(Solution)
我们假设(1)为根。
考虑两种问题,对于每个节点(x):
一种是子树覆盖所有颜色,那答案就是(x)向上走的最远距离。
另一种是子树以外的所有节点(包含该节点)覆盖所有颜色,那答案就是(x)往下走的最远距离。
两个最远距离都可以预处理得到。
考虑第一种情况:我们发现对于每种颜色(coli),颜色为它的节点到根的节点子树都包含这种颜色了。
那么这种颜色的贡献应该是所有该颜色的点到根路径的并,发现可能有重。
我们将这些节点按照(dfn)序排序,发现相邻节点的(lca)是最近的,这样只需要再减去相邻节点的(lca)到根的路径就可以不重不漏地实现修改操作。
然后直接差分即可。
第二种情况:对于每种颜色(coli),颜色为它的节点的总(LCA)即其祖先的子树是完全包含这个颜色的,那么这些节点的反树就不会包含这种颜色,其余的可以。
那么直接将(LCA)打上标记然后搜一遍标记从下向上传即可。
于是这道题就解决了。(考场想的比这个复杂得多。。。)
(Code)
#include <cstdio>
#include <vector>
#include <algorithm>
#define N 1000010
#define ll long long
#define ls (x << 1)
#define rs (x << 1 | 1)
#define mem(x, a) memset(x, a, sizeof x)
#define mpy(x, y) memcpy(x, y, sizeof y)
#define fo(x, a, b) for (int x = (a); x <= (b); x++)
#define fd(x, a, b) for (int x = (a); x >= (b); x--)
#define go(x) for (int p = tail[x], v; p; p = e[p].fr)
using namespace std;
struct node{int v, fr;}e[N << 1];
int n, m, col[N], z[N], top = 0, tail[N], cnt = 0;
int fa[N][21], dep[N], mx[N][2], son[N], mx_up[N];
int dfn[N], st[N], ed[N], ton[N], ans = 0;
int b[N], tag[N];
vector<int> eg[N / 10];
bool label[N];
inline int read() {
int x = 0, f = 0; char c = getchar();
while (c < '0' || c > '9') f = (c == '-') ? 1 : f, c = getchar();
while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return f ? -x : x;
}
inline void add(int u, int v) {e[++cnt] = (node){v, tail[u]}; tail[u] = cnt;}
int LCA(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
for (int i = 0, cha = dep[x] - dep[y]; cha; cha >>= 1, i++)
if (cha & 1) x = fa[x][i];
if (x == y) return x;
fd(i, 19, 0) if (fa[x][i] != fa[y][i])
x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
void well_prepare() {
z[++top] = 1, dep[1] = 1;
while (top) {
int x = z[top];
if (label[x]) {
ed[x] = dfn[0];
if (mx[x][0] == 0) mx[x][0] = 1, son[x] = x;
if (mx[x][0] + 1 > mx[fa[x][0]][0]) {
mx[fa[x][0]][1] = mx[fa[x][0]][0];
mx[fa[x][0]][0] = mx[x][0] + 1, son[fa[x][0]] = x;
}
else if (mx[x][0] + 1 > mx[fa[x][0]][1])
mx[fa[x][0]][1] = mx[x][0] + 1;
top--; continue;
}
dfn[++dfn[0]] = x; st[x] = dfn[0];
for (int i = 0; fa[fa[x][i]][i]; i++)
fa[x][i + 1] = fa[fa[x][i]][i];
go(x) {
if ((v = e[p].v) == fa[x][0]) continue;
fa[v][0] = x, dep[v] = dep[x] + 1, z[++top] = v;
}
label[x] = 1;
}
mx_up[1] = 1;
fo(i, 1, n) {
int x = dfn[i];
go(x) {
if ((v = e[p].v) == fa[x][0]) continue;
mx_up[v] = mx_up[x] + 1;
if (son[x] == v) mx_up[v] = max(mx_up[v], mx[x][1] + 1);
else mx_up[v] = max(mx_up[v], mx[x][0] + 1);
}
}
}
inline bool cmp(int x, int y) {return st[x] < st[y];}
int main()
{
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
n = read(), m = read();
fo(i, 1, n) {
col[i] = read();
eg[col[i]].push_back(i);
}
fo(i, 2, n) {
int u = read(), v = read();
add(u, v), add(v, u);
}
well_prepare();
fo(i, 1, m) {
int lca_ = eg[i][0], len = eg[i].size() - 1;
fo(j, 1, len) lca_ = LCA(lca_, eg[i][j]);
ton[lca_]++;
}
fd(i, n, 1) ton[fa[dfn[i]][0]] += ton[dfn[i]];
fo(i, 1, n) if (! ton[i])
ans = max(ans, mx[i][0] + 1);
fo(i, 1, m) {
int len = eg[i].size();
fo(j, 0, len - 1) b[j + 1] = eg[i][j];
sort(b + 1, b + len + 1, cmp);
fo(j, 1, len) tag[b[j]]++;
fo(j, 2, len) {
int lca = LCA(b[j - 1], b[j]);
tag[lca]--;
}
}
fo(i, 1, n) label[i] = 0;
z[top = 1] = 1;
while (top) {
int x = z[top];
if (label[x]) {
tag[fa[x][0]] += tag[x];
if (tag[x] == m) ans = max(ans, mx_up[x]);
top--; continue;
}
go(x) {
if ((v = e[p].v) == fa[x][0]) continue;
z[++top] = v;
}
label[x] = 1;
}
printf("%d
", ans);
return 0;
}