(color{#0066ff}{题目描述})
给你一棵TREE,以及这棵树上边的距离.问有多少对点它们两者间的距离小于等于K
(color{#0066ff}{输入格式})
N(n<=40000) 接下来n-1行边描述管道,按照题目中写的输入 接下来是k
(color{#0066ff}{输出格式})
一行,有多少对点之间的距离小于等于k
(color{#0066ff}{输入样例})
7
1 6 13
6 3 9
3 5 7
4 1 3
2 4 20
4 7 2
10
(color{#0066ff}{输出样例})
5
(color{#0066ff}{数据范围与提示})
none
(color{#0066ff}{题解})
点分治
拆掉重心,对子树操作
维护一个权值树状数组
搜完一棵子树,把所有的dis放到一个数组里
for在树状数组上统计答案(注意要+1,那是当前点与重心的点对贡献)
之后把当前子树加入树状数组,最后清空
#include<bits/stdc++.h>
using namespace std;
#define LL long long
LL in() {
char ch; int x = 0, f = 1;
while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
return x * f;
}
struct node {
int to, dis;
node *nxt;
node(int to = 0, int dis = 0, node *nxt = NULL): to(to), dis(dis), nxt(nxt) {}
void *operator new (size_t) {
static node *S = NULL, *T = NULL;
return (S == T) && (T = (S = new node[1024]) + 1024), S++;
}
};
const int maxn = 40505;
int siz[maxn], f[maxn], tmp[maxn], ls[maxn];
node *head[maxn];
bool vis[maxn];
int n, k, num, cnt, sum, root;
int ans;
struct BIT {
protected:
int st[maxn];
int low(int x) { return x & (-x); }
public:
void add(int pos, int x) { while(pos <= k) st[pos] += x, pos += low(pos); }
int query(int pos) { int re = 0; while(pos) re += st[pos], pos -= low(pos); return re; }
}b;
void add(int from, int to, int dis) {
head[from] = new node(to, dis, head[from]);
}
void getdis(int x, int fa, int dis) {
ls[++cnt] = dis;
for(node *i = head[x]; i; i = i->nxt)
if(i->to != fa && !vis[i->to])
getdis(i->to, x, dis + i->dis);
}
void getroot(int x, int fa) {
f[x] = 0, siz[x] = 1;
for(node *i = head[x]; i; i = i->nxt) {
if(vis[i->to] || i->to == fa) continue;
getroot(i->to, x);
siz[x] += siz[i->to];
f[x] = std::max(f[x], siz[i->to]);
}
f[x] = std::max(f[x], sum - siz[x]);
if(f[x] < f[root]) root = x;
}
void calc(int x) {
num = 0;
for(node *i = head[x]; i; i = i->nxt) {
if(vis[i->to]) continue;
cnt = 0;
getdis(i->to, 0, i->dis);
for(int j = 1; j <= cnt; j++) if(ls[j] <= k) ans += b.query(k - ls[j]) + 1;
for(int j = 1; j <= cnt; j++) if(ls[j] <= k) b.add(ls[j], 1), tmp[++num] = ls[j];
}
for(int i = 1; i <= num; i++) b.add(tmp[i], -1);
}
void work(int x) {
vis[x] = true;
calc(x);
for(node *i = head[x]; i; i = i->nxt) {
if(vis[i->to]) continue;
root = 0, sum = siz[i->to];
getroot(i->to, 0);
work(root);
}
}
int main() {
n = in();
int x, y, z;
for(int i = 1; i < n; i++) {
x = in(), y = in(), z = in();
add(x, y, z), add(y, x, z);
}
k = in();
f[0] = sum = n;
getroot(1, 0);
work(root);
printf("%d", ans);
return 0;
}