题意
给定一颗大小为(n)的点带权无根树,有(q)个询问,每次询问与结点(u)距离不超过(k)的结点的点权之和
(kleq 400,qleq 5000,nleq 10^6)
解法
询问与节点(u)距离不超过(k)的结点的点权,考虑这些点的来源,一是来源于它的子树内,二是来源于它的祖宗链
假设我们求得了以某个结点为根,对于节点(u),在其子树内距离其不超过(k)的的点的权值和(f_u[k])
那么我们就解决了这些点的第一个来源
接下来对其祖宗链上的答案进行统计:可以利用一个小容斥,对于结点(u)的(t)位祖先(v),它对答案的贡献是(f_v[k-t]-f_{son_v}[k-t-1]),其中(son_v)指的是(v)在这条链上的儿子
考虑如何求(f)数组
由于只有询问没有修改,我们考虑离线处理,按照(dfs)序排序
把整颗树抽象称为一个二维平面,其中横轴为(dfs)序,纵轴为深度
由于一个点的子树内,深度是连续的,(dfs)序也是连续的,所以每次查询相当于查询一个矩形内点的权值和
我们把这个矩形的左右边界拆开,差分一下即可求出这个矩形内点的权值
离线后扫一遍算出贡献即可
代码
#include <cstdio>
#include <vector>
#include <cstring>
using namespace std;
const int N = 2e6 + 10;
struct node { int x, v, id; };
int n, q;
int cnt;
int fa[N], mp[N], dep[N], dfn[N], sz[N];
int cap;
int head[N], to[N << 1], nxt[N << 1];
long long p[N], ans[N];
vector<node> g[N];
struct BIT {
long long c[N];
BIT() { memset(c, 0, sizeof c); }
void insert(int x, long long v) {
for (; x && x <= n; x += x & -x) c[x] += v;
}
long long query(int x) {
x = min(x, n);
long long res = 0;
for (; x; x -= x & -x) res += c[x];
return res;
}
} bit;
template<typename _T> void read(_T& x) {
int c = getchar(); x = 0;
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
}
inline void add(int x, int y) { to[++cap] = y, nxt[cap] = head[x], head[x] = cap; }
void DFS(int x) {
dfn[x] = ++cnt, mp[cnt] = x, sz[x] = 1;
for (int i = head[x]; i; i = nxt[i])
if (to[i] != fa[x])
dep[to[i]] = dep[x] + 1, fa[to[i]] = x, DFS(to[i]), sz[x] += sz[to[i]];
}
int main() {
read(n);
for (int i = 1; i <= n; ++i) read(p[i]);
int u, v;
for (int i = 1; i < n; ++i) {
read(u), read(v);
add(u, v), add(v, u);
}
dep[1] = 1;
DFS(1);
read(q);
int x, k;
for (int i = 1; i <= q; ++i) {
read(x), read(k);
int p = x, lst = 0;
while (p && k >= 0) {
int pos = dfn[p];
g[pos - 1].push_back((node){dep[p] + k, -1, i});
g[pos + sz[p] - 1].push_back((node){dep[p] + k, 1, i});
if ((p ^ x) && k) {
pos = dfn[lst];
g[pos - 1].push_back((node){dep[lst] + k - 1, 1, i});
g[pos + sz[lst] - 1].push_back((node){dep[lst] + k - 1, -1, i});
}
lst = p, k--;
p = fa[p];
}
}
for (int i = 1; i <= n; ++i) {
bit.insert(dep[mp[i]], p[mp[i]]);
for (int j = 0; j < g[i].size(); ++j)
ans[g[i][j].id] += 1LL * g[i][j].v * bit.query(g[i][j].x);
}
for (int i = 1; i <= q; ++i) printf("%lld
", ans[i]);
return 0;
}