hdu4757
题意
给出一棵树,每个节点有权值,每次查询节点 ((u, v)) 以及 (x) ,问 (u) 到 (v) 路径上的某个节点与 (x) 异或最大的值是多少。
分析
Trie 的新姿势!
如果直接问 (x) 与某个区间中哪个数异或后最大,那么直接把区间所有数转化成二进制数(从高到低位,保证长度相同,前面不足补 (0) )插入到字典树中,然后查询将每一位都取反的 (x) ,比如 (x) 为 (0101) ,那么我们去查询 (1010) ,如果存在对应的位 (i),加上 (1 << i) 即可。
但是现在有多个查询,怎么优化呢?
对于每一节点都建立一颗 (Trie) 树,用一个数组 (sz) 记录前缀的数量,如果 (v) 是 (u) 的子节点,比如 (v) 的权值是 (010) ,(u) 的权值是 (011) ,假设 (u) 已经加入,那么在加入 (v) 的时候,发现 (01) 这个前缀数量已经为 (1) 了,正好 (v) 也有一个这样的前缀,所以 (sz[now]+=1) ,也就是说继承了父节点那个 (Trie) 对应的前缀的数量,且加上自己的一个。到 (010) 时,发现存在 (011) 这个前缀,除了新增 (010) 这个前缀,对于 (011) 我们直接指向父节点的这个对应的值。也就是说我们只更新我们新加入的数对应的那些前缀的数量,其它的全部和父亲节点的 (Trie) 树保持一致。也就是每个节点记录的前缀的数量是从当前节点到根节点的所有相应前缀的数量。
再回到上面的做法,当我们去查询每一位都取反的 (x) 时,我们其实只关心是否某一位有对应的存在。
如果 (x) 为 (1001) ,取反后为 (0110) ,那么我们其实是想知道 (u) 到 (v) 这条路径是否有某个前缀为 (0) 的数,设 (k = LCA(u, v)) , 设 (F()) 为到根节点前缀为 (0) 的数量,那么计算判断有没有只要 (F(u) + F(v) - 2 * F(k) > 0) 即可。(如果这一步取到前缀 (0) 了,后面去查找 (01) 这个前缀是否存在,如果没有取到,那么取的只能是 (1) ,那么就去查找 (11) 是否存在)
还有注意这样算 (k) 这个节点的权值 (a[k]) 是没有算过的,所以最后最后要和 (a[k] xor x) 比较一下取得最大值。
code
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2e6 + 10;
const int LOGN = 20;
int n, m;
int a[MAXN];
vector<int> G[MAXN];
int dep[MAXN];
int p[MAXN][LOGN];
void init() {
for(int i = 1; i < LOGN; i++) {
for(int j = 1; j <= n; j++) {
p[j][i] = p[p[j][i - 1]][i - 1];
}
}
}
int lca(int u, int v) {
if(dep[u] > dep[v]) swap(u, v);
for(int i = 0; i < LOGN; i++) {
if((dep[v] - dep[u]) >> i & 1) {
v = p[v][i];
}
}
if(v == u) return u;
for(int i = LOGN - 1; i >= 0; i--) {
if(p[u][i] != p[v][i]) {
u = p[u][i];
v = p[v][i];
}
}
return p[u][0];
}
int sz[MAXN];
int nxt[MAXN][2], root[MAXN], L;
int newnode() {
nxt[L][0] = nxt[L][1] = 0;
return L++;
}
void insert(int u, int fa, int x) {
int now1 = root[u], now2 = root[fa];
for(int i = 18; i >= 0; i--) {
int d = (x >> i) & 1;
nxt[now1][d] = newnode();
nxt[now1][!d] = nxt[now2][!d];
now1 = nxt[now1][d]; now2 = nxt[now2][d];
sz[now1] = sz[now2] + 1;
}
}
int query(int u, int v, int x) {
int k = lca(u, v);
int res = 0;
int now1 = root[u], now2 = root[v], now3 = root[k];
for(int i = 18; i >= 0; i--) {
int d = (x >> i) & 1;
if(nxt[now1][!d] + nxt[now2][!d] - 2 * nxt[now3][!d] > 0) {
res += (1 << i);
d = !d;
}
now1 = nxt[now1][d];
now2 = nxt[now2][d];
now3 = nxt[now3][d];
}
return max(res, x ^ a[k]);
}
void dfs(int fa, int u) {
root[u] = newnode();
insert(u, fa, a[u]);
p[u][0] = fa;
dep[u] = dep[fa] + 1;
for(int i = 0; i < (int)G[u].size(); i++) {
int v = G[u][i];
if(v != fa) {
dfs(u, v);
}
}
}
int main() {
while(~scanf("%d%d", &n, &m)) {
L = 1;
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
G[i].clear();
}
for(int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs(0, 1);
init();
while(m--) {
int x, y, z;
scanf("%d%d%d", &x, &y, &z);
printf("%d
", query(x, y, z));
}
}
return 0;
}