题目描述
给定一棵树,每一个点带权 (v_i)。
每一个点的答案为其子树内节点"点权加距离"的异或和。
求所有点答案的和。
做法
把每一位拆开来考虑贡献。
考虑每一个点对其祖先的贡献,发现产生贡献的点是若干段深度连续的区间。
举个例子:
节点点 (u) 深度为 (16),权值为 (0),考虑其第二进制位上第二位((2^2 = 4))对其祖先的贡献。
只有深度在 ([1,4] igcup [9,12]) 的节点 (v) 会受到节点 (u) 的影响。
形式化来讲,只有深度 (d) 在模 (2^3 = 8) 意义下属于 $ [1, 4]$ 的节点 (v) 会受到节点 (u) 的影响。
上面举的是一个特殊的例子,来个更加形式化的表述吧。
考虑第 (b) 位,对于一个点 (u),它的深度为 (d_u),思考它哪些些深度 (d_v) 会产生贡献。
[2^b le v_u + d_u - d_v < 2^{b + 1}
]
解个不等式就能知道深度 (d_v) 的区间了。
发现这个性质就好办了,分位考虑后模意义下修改的是一段区间,用树状数组做。
要求的点在子树内,用树差分做。
总复杂度 (O(n log^2 n))。
洛谷评测机太慢需要开 O2
才能通过。
后记
开考前看见考试包里有个题叫 tree
,我就觉得我要翻盘了。
开场搞这题,后面确实切掉了。
但是第一题这简单题怎么就我没切啊。。。
退役了退役了。
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 525015;
int n;
int fa[N], val[N];
int head[N], nex[N], to[N], ecnt;
int full[21], is[N][21];
inline int read() {
int x = 0; char ch = getchar();
while(!isdigit(ch)) ch = getchar();
while(isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
return x;
}
inline void addE(int u, int v) {
to[++ecnt] = v;
nex[ecnt] = head[u], head[u] = ecnt;
}
struct BIT {
#define lowbit(x) (x & -x)
int upp;
vector<int> c;
void init(int upLim) {
upp = upLim;
c.resize(upLim);
}
void update(int p) {
if(!p) return c[0] ^= 1, void();
for(int i = p; i < upp; i += lowbit(i))
c[i] ^= 1;
}
int xorSum(int p) {
int res = c[0];
for(int i = p; i; i -= lowbit(i))
res ^= c[i];
return res;
}
#undef lowbit
} tr[21];
inline void update(int b, int x, int d) {
int l = ((1 << b) - x + d) & full[b], r = ((1 << (b + 1)) - x + d) & full[b];
if(l < r) {
tr[b].update(l), tr[b].update(r);
} else {
tr[b].update(0), tr[b].update(r), tr[b].update(l);
}
}
inline int calc(int b, int d) { return tr[b].xorSum(d & full[b]); }
long long ans;
void dfs(int u, int dep) {
for(int b = 0; b <= 20; ++b)
is[u][b] ^= calc(b, dep);
for(int i = head[u]; i; i = nex[i])
dfs(to[i], dep - 1);
int res = 0;
for(int b = 0; b <= 20; ++b) {
update(b, val[u], dep);
is[u][b] ^= calc(b, dep);
if(is[u][b]) res += 1 << b;
ans += is[u][b] ? 1 << b : 0;
}
}
int main() {
n = read();
for(int i = 1; i <= n; ++i)
val[i] = read();
for(int i = 2; i <= n; ++i) {
fa[i] = read();
addE(fa[i], i);
}
for(int b = 0; b <= 20; ++b)
full[b] = (1 << (b + 1)) - 1;
for(int b = 0; b <= 20; ++b)
tr[b].init(1 << (b + 1));
dfs(1, n);
printf("%lld
", ans);
return 0;
}