本文代码来源:https://blog.csdn.net/yang_7_46/article/details/9966455
本文参考论文来源:https://wenku.baidu.com/view/8861df38376baf1ffc4fada8.html?re=view
基于点分治的树分治算法
poj1741
基本步骤:1.一次dfs计算每个点的size
2.一次dfs找到重心
3.一次dfs计算每个点的深度
4.计算以root为重心情况下dis[u]+dis[v]<=k的点对数,u,v是root的儿子
5.遍历root的各个子树,对于每个子树
(1):减去dis[u]+dis[v]<=k的点对数,u,v是当前子树的儿子(为什么要减去?dis[u]+dis[v]=u,v的路径长度+2*dis[lca(u,v)],不是u,v的距离)
(2):开始分治
复杂度计算,分治复杂度O(logN),每次分治最大复杂度O(NlogN),总复杂度O(NlogN*logN)
#include<iostream> #include<cstring> #include<cstdio> #include<algorithm> #define MAXN 10010 using namespace std; int N,K; int ans,root,Max; struct node{ int v,next,w; }edge[MAXN<<1]; int head[MAXN],tot; int size[MAXN]; int maxv[MAXN]; int vis[MAXN]; int dis[MAXN]; int num; void init(){ tot=ans=0; memset(head,-1,sizeof head); memset(vis,0,sizeof vis); } void addedge(int u,int v,int w){ edge[tot].v=v; edge[tot].w=w; edge[tot].next=head[u]; head[u]=tot++; } //一次dfs处理子树的大小 void dfssize(int u,int f){ size[u]=1; maxv[u]=0; for(int i=head[u];i!=-1;i=edge[i].next){ int v=edge[i].v; if(v==f||vis[v]) continue; dfssize(v,u); size[u]+=size[v]; if(size[u]>maxv[u])maxv[u]=size[u]; } } //一次dfs找重心,这里的r不是重心 void dfsroot(int r,int u,int f){ if(size[r]-size[u]>maxv[u]) maxv[u]=size[r]-size[u]; if(maxv[u]<Max) Max=maxv[u],root=u; for(int i=head[u];i!=-1;i=edge[i].next){ int v=edge[i].v; if(v==f||vis[v]) continue; dfsroot(r,v,u); } } //一次dfs求每个点到重心的距离 void dfsdis(int u,int d,int f){ dis[num++]=d; for(int i=head[u];i!=-1;i=edge[i].next){ int v=edge[i].v; if(v!=f && !vis[v]) dfsdis(v,d+edge[i].w,u); } } int calc(int u,int d){ int ret=0; num=0; dfsdis(u,d,0);//求每个点到根的距离 sort(dis,dis+num); int i=0,j=num-1; while(i<j){ while(dis[i]+dis[j]>K && i<j) j--; ret+=j-i; i++; } return ret; } void dfs(int u){//注意这里的u并不是重心 Max=N; dfssize(u,0);//求每个子树的规模 dfsroot(u,u,0);//求重心 ans+=calc(root,0);//求以root为根,dis[u]+dis[v]的点对有多少 vis[root]=1;//把这个点从点集删掉 for(int i=head[root];i!=-1;i=edge[i].next){ int v=edge[i].v; if(!vis[v]){ ans-=calc(v,edge[i].w); dfs(v); } } } int main(){ while(scanf("%d%d",&N,&K)!=EOF){ if(!N) break; int u,v,w; init(); for(int i=1;i<N;i++){ scanf("%d%d%d",&u,&v,&w); addedge(u,v,w); addedge(v,u,w); } dfs(1); printf("%d ",ans); } return 0; }