首先,这是一道坑题:我用对拍跑了几百组数据都是对的,交上去却WA了,原因见文末……
一开始我以为要用树链剖分,后来ZYH指出,其实只要用dfs序就可以解决。
解法是:对每个点,把它到根的链上的权值建成一棵(权值)线段树。直接建会MLE,于是改用可持久化线段树——每个点的线段树都是在其父亲的版本上插入自己的权值得到的,只新建O(log n)个节点。
查询时,先求出两点a、b的LCA,记为c,再取c的父亲,记为d。a到b路径上每个权值区间的计数就等于 seg[a] + seg[b] - seg[c] - seg[d]。在这四棵线段树上同步往下走:每一层用左子树的这个差值与k比较,若不小于k就进入左子树,否则k减去该差值后进入右子树,走到叶子即为答案。
/**************************************************************
    Problem: 2588
    User: rausen
    Language: C++
    Result: Accepted
    Time:4396 ms
    Memory:51788 kb
****************************************************************/

// K-th smallest value on a tree path, answered with persistent segment trees
// built along every root-to-node chain.

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>

using namespace std;

// Capacities. The input bounds are not visible here; these presumably cover
// n, m <= 1e5 (TODO confirm against the problem statement).
const int MAXN = 150000;    // tree nodes, 1-based
const int MAXE = 250000;    // adjacency entries, two per undirected edge
const int MAXSEG = 2500000; // persistent segment tree nodes, about n * (log2(n) + 2)
const int LOG = 17;         // binary-lifting levels fa[0..LOG-1]; handles depth < 2^LOG

struct tree_node {
    int pos;      // 1-based index of the node in DFS preorder
    int dep;      // depth; the root is 1, node 0 is a sentinel with depth 0
    int v;        // value, replaced in main() by its rank among distinct values
    int fa[LOG];  // fa[i] = 2^i-th ancestor, or 0 if it does not exist
} tr[MAXN];

struct edges {
    int next, to;
} e[MAXE];

// Node 0 is the shared empty tree: sum 0 and both children 0.
struct segment {
    int lson, rson, sum;
} seg[MAXSEG];

int n, m, tot, TOT, cnt, sz, ans;
int X, Y, K;
int first[MAXN], V[MAXN], N[MAXN], root[MAXN], num[MAXN];
int stk[MAXN];  // explicit stack used by dfs()

// Appends the directed edge x -> y to x's adjacency list.
void add_edge(int x, int y) {
    e[++TOT].next = first[x];
    first[x] = TOT;
    e[TOT].to = y;
}

// Adds the undirected edge x - y as two directed adjacency entries.
void add_Edges(int x, int y) {
    add_edge(x, y);
    add_edge(y, x);
}

// Returns the 1-based rank of x among the sorted distinct values N[1..tot].
// x is expected to be one of those values.
int rank_of(int x) {
    return lower_bound(N + 1, N + tot + 1, x) - N;
}

// Numbers the subtree rooted at p in DFS preorder (num[] / tr[].pos) and fills
// each node's binary-lifting table. An explicit stack replaces recursion so a
// path-shaped tree cannot overflow the call stack. Preorder numbers every node
// after its parent, which both the lifting table and the persistent-tree
// construction in main() depend on.
// The caller must set tr[p].fa[0] and tr[p].dep beforehand.
void dfs(int p) {
    int top = 0;
    stk[++top] = p;
    while (top) {
        int u = stk[top--];
        num[++cnt] = u, tr[u].pos = cnt;
        // A node at depth d has d - 1 proper ancestors, so its 2^x-th ancestor
        // exists iff 2^x < d. Entries left unset stay 0 (zero-initialized globals).
        for (int x = 1; x < LOG && (1 << x) < tr[u].dep; ++x)
            tr[u].fa[x] = tr[tr[u].fa[x - 1]].fa[x - 1];
        for (int i = first[u]; i; i = e[i].next) {
            int y = e[i].to;
            if (tr[u].fa[0] != y) {
                tr[y].fa[0] = u, tr[y].dep = tr[u].dep + 1;
                stk[++top] = y;
            }
        }
    }
}

// Inserts the value rank `val` into the version rooted at x over [l, r],
// writing the new version's root into y. Only the nodes on the root-to-leaf
// path are copied; all other subtrees are shared with version x.
void add(int l, int r, int x, int &y, int val) {
    y = ++sz;
    seg[y].sum = seg[x].sum + 1;
    seg[y].lson = seg[x].lson, seg[y].rson = seg[x].rson;
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (val <= mid)
        add(l, mid, seg[x].lson, seg[y].lson, val);
    else
        add(mid + 1, r, seg[x].rson, seg[y].rson, val);
}

// Lowest common ancestor of x and y via binary lifting.
// Requires tr[].dep and tr[].fa to have been filled by dfs().
int LCA(int x, int y) {
    if (tr[x].dep < tr[y].dep) swap(x, y);
    int diff = tr[x].dep - tr[y].dep;
    for (int i = 0; i < LOG; ++i)
        if (diff & (1 << i)) x = tr[x].fa[i];
    if (x == y) return x;
    for (int i = LOG - 1; i >= 0; --i)
        if (tr[x].fa[i] != tr[y].fa[i])
            x = tr[x].fa[i], y = tr[y].fa[i];
    return tr[x].fa[0];
}

// Returns the K-th smallest original value on the path between nodes x and y.
// With prefix(u) = version holding the values on the root..u chain, the path's
// multiset is prefix(x) + prefix(y) - prefix(lca) - prefix(parent(lca)).
// The four versions are descended in lockstep, going left while the left
// halves together hold at least K values.
// Assumes 1 <= K <= number of nodes on the path.
int query(int x, int y, int K) {
    int c = LCA(x, y), d = tr[c].fa[0];
    int a = root[tr[x].pos], b = root[tr[y].pos];
    c = root[tr[c].pos], d = root[tr[d].pos];  // d == 0 maps to the empty tree
    int l = 1, r = tot;
    while (l < r) {
        int mid = (l + r) >> 1;
        int in_left = seg[seg[a].lson].sum + seg[seg[b].lson].sum
                    - seg[seg[c].lson].sum - seg[seg[d].lson].sum;
        if (in_left >= K) {
            r = mid;
            a = seg[a].lson, b = seg[b].lson, c = seg[c].lson, d = seg[d].lson;
        } else {
            K -= in_left, l = mid + 1;
            a = seg[a].rson, b = seg[b].rson, c = seg[c].rson, d = seg[d].rson;
        }
    }
    return N[l];
}

int main() {
    if (scanf("%d%d", &n, &m) != 2 || n < 1) return 0;
    for (int i = 1; i <= n; ++i) {
        if (scanf("%d", &tr[i].v) != 1) return 0;
        V[i] = tr[i].v;
    }

    // Coordinate compression: N[1..tot] holds the distinct values ascending.
    sort(V + 1, V + n + 1);
    N[++tot] = V[1];
    for (int i = 2; i <= n; ++i)
        if (V[i] != V[i - 1])
            N[++tot] = V[i];
    for (int i = 1; i <= n; ++i)
        tr[i].v = rank_of(tr[i].v);

    for (int i = 1; i < n; ++i) {
        if (scanf("%d%d", &X, &Y) != 2) return 0;
        add_Edges(X, Y);
    }

    // The root must have depth 1: dfs() tests (1 << x) < dep for ancestor existence.
    cnt = 0;
    tr[1].fa[0] = 0, tr[1].dep = 1;
    dfs(1);

    // Version i extends its parent's version with node num[i]'s value. Preorder
    // guarantees the parent's version already exists; the root's parent is
    // node 0, whose pos is 0 and whose version root[0] = 0 is the empty tree.
    root[0] = 0, seg[0].sum = seg[0].lson = seg[0].rson = 0;
    for (int i = 1; i <= n; ++i) {
        int t = num[i];
        add(1, tot, root[tr[tr[t].fa[0]].pos], root[i], tr[t].v);
    }

    while (m--) {
        if (scanf("%d%d%d", &X, &Y, &K) != 3) break;
        X ^= ans;  // the first endpoint arrives XOR-ed with the previous answer
        ans = query(X, Y, K);
        printf("%d\n", ans);
    }
    return 0;
}
(WA的原因很坑:我原来把根的深度设成了0,结果倍增的时候出了问题(还RE了)……代码里根的深度已设为1。)