题意
给定一个 (n) 个点的带权树 (边权为正整数), 求树中距离小于等于 (k) 的点对数量.
(1 le n le 4 imes10^4, k le 2 imes10^4)
思路
点分治.
对树上的点到当前根节点的距离建一个桶, 并在这个桶上建树状数组.
统计点对数量时在树状数组上查找 (k-dis[u]) 的前缀和即可.
时间复杂度 (O(nlog^2 n)).
其实不用树状数组也可以, 因为边权为正, 所以越往下, (dis) 越大,
那么只需要维护一个指针 (t=k-dis[u]), 指针往左移动的时候, 减去指针所指位置的节点数量, 剩下的就是满足 (dis[v]+dis[u] le k) 的点.
时间复杂度 $O(nlog n) $.
但是在洛谷上跑出来反而更慢了....
代码
(O(nlog^2 n))
#include<bits/stdc++.h>
using namespace std;
const int _=2e4+7;
const int __=4e4+7;
const int ___=8e4+7;
const int inf=0x3f3f3f3f;
int n,k,dis[__],sz[__],rt,minx=inf,ans,q[__],top;
int lst[__],nxt[___],to[___],len[___],tot;
int c[_];
bool vis[__];
void add(int x,int y,int w){ nxt[++tot]=lst[x]; to[tot]=y; len[tot]=w; lst[x]=tot; }
void pre(int u,int fa){
sz[u]=1;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
pre(v,u);
sz[u]+=sz[v];
}
}
void g_rt(int u,int fa,int sum){
int maxn=sum-sz[u];
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
g_rt(v,u,sum);
maxn=max(maxn,sz[v]);
}
if(maxn<minx){ minx=maxn; rt=u; }
}
void modify(int x,int v){
for(int i=x;i<=k;i+=i&(-i)){
c[i]+=v;
}
}
int query(int x){
int res=0;
for(int i=x;i;i-=i&(-i))
res+=c[i];
return res;
}
void cnt(int u,int fa){
if(dis[u]>k) return;
ans+=query(k-dis[u])+1;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+len[i];
cnt(v,u);
}
}
void mrk(int u,int fa){
if(dis[u]>k) return;
modify(dis[u],1);
q[++top]=dis[u];
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
mrk(v,u);
}
}
void calc(int u){
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]) continue;
dis[v]=len[i];
cnt(v,0);
mrk(v,0);
}
for(int i=1;i<=top;i++) modify(q[i],-1);
top=0;
}
void run(int u){
pre(u,0);
minx=inf;
g_rt(u,0,sz[u]);
u=rt;
vis[u]=1;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]) continue;
run(v);
}
calc(u);
vis[u]=0;
}
int main(){
//freopen("x.in","r",stdin);
//freopen("x.out","w",stdout);
cin>>n; int x,y,w;
for(int i=1;i<n;i++){
scanf("%d%d%d",&x,&y,&w);
add(x,y,w);
add(y,x,w);
}
cin>>k;
run(1);
printf("%d
",ans); // 一对点只会计算到一次, 所以不用 /2
return 0;
}
(O(nlog n))
#include<bits/stdc++.h>
using namespace std;
const int _=2e4+7;
const int __=4e4+7;
const int ___=8e4+7;
const int inf=0x3f3f3f3f;
int n,k,dis[__],sz[__],rt,minx=inf,ans,q[__],top,all;
int lst[__],nxt[___],to[___],len[___],tot;
int c[_];
bool vis[__];
void add(int x,int y,int w){ nxt[++tot]=lst[x]; to[tot]=y; len[tot]=w; lst[x]=tot; }
void pre(int u,int fa){
sz[u]=1;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
pre(v,u);
sz[u]+=sz[v];
}
}
void g_rt(int u,int fa,int sum){
int maxn=sum-sz[u];
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
g_rt(v,u,sum);
maxn=max(maxn,sz[v]);
}
if(maxn<minx){ minx=maxn; rt=u; }
}
void cnt(int u,int fa,int t,int res){
if(dis[u]>k) return;
while(t>k-dis[u]) res-=c[t--];
ans+=res;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+len[i];
cnt(v,u,t,res);
}
}
void mrk(int u,int fa){
if(dis[u]>k) return;
c[dis[u]]++; all++;
q[++top]=dis[u];
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(v==fa||vis[v]) continue;
mrk(v,u);
}
}
void calc(int u){
all=1;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]) continue;
dis[v]=len[i];
cnt(v,0,k,all);
mrk(v,0);
}
for(int i=1;i<=top;i++) c[q[i]]=0;
top=all=0;
}
void run(int u){
pre(u,0);
minx=inf;
g_rt(u,0,sz[u]);
u=rt;
vis[u]=1;
for(int i=lst[u];i;i=nxt[i]){
int v=to[i];
if(vis[v]) continue;
run(v);
}
calc(u);
vis[u]=0;
}
int main(){
//freopen("x.in","r",stdin);
//freopen("x.out","w",stdout);
cin>>n; int x,y,w;
for(int i=1;i<n;i++){
scanf("%d%d%d",&x,&y,&w);
add(x,y,w);
add(y,x,w);
}
cin>>k;
run(1);
printf("%d
",ans); // 一对点只会计算到一次, 所以不用 /2
return 0;
}