BZOJ_1776
本来以为是一道树分治(点分治)的题目,结果写完之后跑了 8 s 多,复杂度大约是 $O(N\log^2 N)$……比较好的办法还是转化成 LCA 问题:对每个党派,先取其中深度最大的点,它一定是该党派最远点对的一个端点,再用 LCA 求出其余同党派点到它的距离并取最大值即可。具体思路可以参考这篇博客:http://hi.baidu.com/edward_mj/item/2b46d4330c23edc61b9696c2(原链接可能已失效)。
#include <stdio.h>
#include <string.h>
#include <vector>
#include <algorithm>

// BZOJ 1776. Node i (1..N) has colour col[i]; the tree is given as parent
// links (parent 0 marks the root). For every colour 1..K print the largest
// distance, in edges, between two nodes of that colour (0 when the colour has
// fewer than two nodes).
//
// Centroid decomposition: at each centroid, every neighbouring sub-component
// is scanned once by BFS to find the deepest node of each colour in it; that
// depth is paired with the deepest same-coloured node seen in earlier
// sub-components (or the centroid itself, depth 0). The per-colour "deepest so
// far" values live in timestamped arrays instead of a sorted list, so each
// decomposition level costs linear time and the whole run is O(N log N).

const int MAXD = 200010;    // capacity for node ids and colour ids
const int MAXM = 2 * MAXD;  // directed edges: each non-root node adds two
const int INF = 0x3f3f3f3f;

int N, K;
int first[MAXD], e, nxt[MAXM], v[MAXM], col[MAXD], T;

int ans[MAXD];                               // answer per colour
int sz[MAXD], fa[MAXD], dep[MAXD], q[MAXD];  // BFS scratch space
int del[MAXD];  // 1 once a node has been used as a centroid

// Per-colour merge state. An entry is valid only while its *Mark equals the
// current stamp; stamps only grow, so nothing needs clearing between
// centroids or between test cases.
int bestDep[MAXD], bestMark[MAXD];  // deepest depth so far at current centroid
int subDep[MAXD], subMark[MAXD];    // deepest depth in current sub-component
int seen[MAXD];                     // distinct colours of current sub-component
int centroidStamp = 0, subStamp = 0;

// Append the directed edge x -> y to the adjacency list.
void add(int x, int y) {
    v[e] = y;
    nxt[e] = first[x];
    first[x] = e++;
}

// Read the N (colour, parent) lines of one test case and build the
// undirected tree. Returns false on truncated or out-of-range input.
bool init() {
    if (N < 0 || N >= MAXD || K < 0 || K >= MAXD) return false;
    memset(first, -1, sizeof(first[0]) * (N + 1));
    e = 0;
    T = 1;  // fallback start node; overwritten by the node whose parent is 0
    for (int i = 1; i <= N; i++) {
        int a, p;
        if (scanf("%d%d", &a, &p) != 2) return false;
        if (a < 0 || a >= MAXD || p < 0 || p > N) return false;
        col[i] = a;
        if (p == 0) {
            T = i;
        } else {
            add(p, i);
            add(i, p);
        }
    }
    return true;
}

// Return a centroid of the undeleted component containing cur: the node whose
// largest remaining piece, after removing it, is smallest.
int findroot(int cur) {
    int rear = 0;
    q[rear++] = cur;
    fa[cur] = -1;
    for (int i = 0; i < rear; i++) {
        int x = q[i];
        for (int j = first[x]; j != -1; j = nxt[j])
            if (!del[v[j]] && v[j] != fa[x]) {
                fa[v[j]] = x;
                q[rear++] = v[j];
            }
    }
    int root = cur, bestPiece = INF;
    // Reverse BFS order visits children before parents, so sz[] is ready.
    for (int i = rear - 1; i >= 0; i--) {
        int x = q[i], heaviest = 0;
        sz[x] = 1;
        for (int j = first[x]; j != -1; j = nxt[j])
            if (!del[v[j]] && v[j] != fa[x]) {
                heaviest = std::max(heaviest, sz[v[j]]);
                sz[x] += sz[v[j]];
            }
        heaviest = std::max(heaviest, rear - sz[x]);  // the part above x
        if (heaviest < bestPiece) {
            bestPiece = heaviest;
            root = x;
        }
    }
    return root;
}

// BFS the undeleted component entered at `start` (the already-deleted
// centroid blocks the way back). Depths are measured from the centroid, so
// start has depth 1. For each colour the deepest depth goes to subDep[], and
// the distinct colours are listed in seen[]. Returns how many were listed.
int scan_component(int start) {
    int stamp = ++subStamp, distinct = 0, rear = 0;
    q[rear++] = start;
    fa[start] = -1;
    dep[start] = 1;
    for (int i = 0; i < rear; i++) {
        int x = q[i], c = col[x];
        if (subMark[c] != stamp) {
            subMark[c] = stamp;
            subDep[c] = dep[x];
            seen[distinct++] = c;
        } else if (dep[x] > subDep[c]) {
            subDep[c] = dep[x];
        }
        for (int j = first[x]; j != -1; j = nxt[j])
            if (!del[v[j]] && v[j] != fa[x]) {
                fa[v[j]] = x;
                dep[v[j]] = dep[x] + 1;
                q[rear++] = v[j];
            }
    }
    return distinct;
}

// Process the component containing cur: combine all same-coloured pairs whose
// path passes through its centroid, then recurse into the remaining pieces.
void dfs(int cur) {
    int root = findroot(cur);
    int stamp = ++centroidStamp;
    del[root] = 1;
    bestMark[col[root]] = stamp;
    bestDep[col[root]] = 0;  // the centroid itself is at distance 0
    for (int i = first[root]; i != -1; i = nxt[i]) {
        if (del[v[i]]) continue;
        int distinct = scan_component(v[i]);
        for (int k = 0; k < distinct; k++) {
            int c = seen[k], d = subDep[c];
            if (bestMark[c] == stamp) {
                // Pair with the deepest node of colour c in an earlier piece.
                ans[c] = std::max(ans[c], bestDep[c] + d);
                bestDep[c] = std::max(bestDep[c], d);
            } else {
                bestMark[c] = stamp;
                bestDep[c] = d;
            }
        }
    }
    // Recurse only after merging, so nested centroids cannot clobber the
    // shared per-colour arrays this centroid is still using.
    for (int i = first[root]; i != -1; i = nxt[i])
        if (!del[v[i]]) dfs(v[i]);
}

// Run the decomposition for one test case and print one answer per colour.
void solve() {
    memset(ans, 0, sizeof(ans[0]) * (K + 1));
    memset(del, 0, sizeof(del[0]) * (N + 1));
    if (N > 0) dfs(T);
    for (int i = 1; i <= K; i++) printf("%d\n", ans[i]);
}

int main() {
    while (scanf("%d%d", &N, &K) == 2) {
        if (!init()) break;
        solve();
    }
    return 0;
}