https://codeforces.com/contest/375/problem/D
这道题也有用树上莫队的解法,这里再给一个树上启发式合并的解法。
树上启发式合并,类似轻重链剖分,先算出每个节点的重儿子,然后计算答案时先递归计算轻儿子的答案,标记clr为true,然后计算重儿子的答案,clr为false,然后把轻儿子的节点暴力插到重儿子里,最后把父节点自己加入。
int n;
vector<int> G[MAXN];
int siz[MAXN], mch[MAXN];
void dfs1(int u, int p) {
siz[u] = 1, mch[u] = 0;
for (int &v : G[u]) {
if (v == p)
continue;
dfs1(v, u);
siz[u] += siz[v];
if (siz[mch[u]] < siz[v])
mch[u] = v;
}
}
由于树上启发式合并并不关心深度,所以没有必要维护深度。
void calc(int u, int p, int skip, int d) {
bit.Add(cnt[c[u]], -1);
cnt[c[u]] += d;
bit.Add(cnt[c[u]], 1);
for (int v : G[u]) {
if (v == p || v == skip)
continue;
calc(v, u, 0, d);
}
}
void dfs2(int u, int p, bool keep) {
for (int &v : G[u]) {
if (v == p || v == mch[u])
continue;
dfs2(v, u, false);
}
if (mch[u])
dfs2(mch[u], u, true);
calc(u, p, mch[u], 1);
for (pii &q : Q[u]) {
int id = q.first, k = q.second;
ans[q.first] = bit.Sum(k, n);
}
if (!keep)
calc(u, p, 0, -1);
}
然后是主要的计算过程dfs2,dfs2优先进入所有的轻儿子,并且不keep轻儿子的答案,保持树状数组为空。然后进入重儿子计算并keep重儿子的结果。这里使用一个辅助函数calc,calc的修改值为1时表示向树状数组中添加,然后命令其在添加时skip掉重儿子。计算完毕后树状数组中存着这棵子树对应的状态,然后取出所有的询问进行回答。那之后,假如不keep树状数组,调用calc修改值为-1,并且不跳过重儿子,把整棵子树删除干净。
时间复杂度为 (O(nlog^2n))
https://codeforces.com/gym/102832/problem/F
这里的查询要去重,所以要先计算再查询。而且要注意cache的命中。一次树遍历就统计出所有的信息,把常用的局部值放在数组的低维。
int n, k;
int a[MAXN];
vector<int> G[MAXN];
int siz[MAXN], mch[MAXN];
void dfs1(int u, int p) {
siz[u] = 1, mch[u] = 0;
for (int &v : G[u]) {
if (v == p)
continue;
dfs1(v, u);
siz[u] += siz[v];
if (siz[mch[u]] < siz[v])
mch[u] = v;
}
}
int cnt[1 << 20][17][2];
ll ans;
void calc1(int u, int p, int LCA) {
int val = a[u] ^ a[LCA];
for (int k = 16; k >= 0; --k) {
int uk = (u >> k) & 1;
ans += (1LL << k) * cnt[val][k][uk ^ 1];
}
for (int &v : G[u]) {
if (v == p)
continue;
calc1(v, u, LCA);
}
}
void calc2(int u, int p) {
int val = a[u];
for (int k = 16; k >= 0; --k) {
int uk = (u >> k) & 1;
++cnt[val][k][uk];
}
for (int &v : G[u]) {
if (v == p)
continue;
calc2(v, u);
}
}
void calc3(int u, int p) {
int val = a[u];
memset(cnt[val], 0, sizeof(cnt[val]));
for (int &v : G[u]) {
if (v == p)
continue;
calc3(v, u);
}
}
void dfs2(int u, int p, bool keep) {
for (int &v : G[u]) {
if (v == p || v == mch[u])
continue;
dfs2(v, u, false);
}
if (mch[u])
dfs2(mch[u], u, true);
int val = a[u];
for (int k = 16; k >= 0; --k) {
int uk = (u >> k) & 1;
ans += (1LL << k) * cnt[0][k][uk ^ 1];
++cnt[val][k][uk];
}
for (int &v : G[u]) {
if (v == p || v == mch[u])
continue;
calc1(v, u, u);
calc2(v, u);
}
if (!keep) {
memset(cnt[val], 0, sizeof(cnt[val]));
for (int &v : G[u]) {
if (v == p)
continue;
calc3(v, u);
}
}
}
void solve() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", &a[i]);
for (int i = 1; i <= n; ++i)
G[i].clear();
for (int i = 1; i <= n - 1; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 0, true);
printf("%lld
", ans);
}