点分治,是一种针对可带权树上简单路径统计问题的算法。
就 POJ 1741 来说:
问题:给一棵边带权树,问两点之间的距离小于等于$k$的点对有多少个。
解决:
当前有一个节点$u$,那么树上的路径可分为两种:(1) 经过节点$u$的 (2) 不经过节点$u$的
第 (2) 种路径,一定在$u$的某个子节点构成的子树中。在各个子树中找一点递归下去即可。
1 void solve(int u) { 2 ans+=calc(u,0); vis[u]=1; 3 for(int i=fro[u];i;i=nxt[i]) { 4 int v=to[i]; 5 if(vis[v]) continue; 6 ans-=calc(v,w[i]); 7 //合并路径时,u的同一个子树下的两点合并出的路径是不存在的。在此减去。 8 root=0,size=sz[v]; 9 findrt(v,0); solve(root); 10 } 11 }
找什么样的点递归下去使得效率最高?
递归层数要最少,所以应选一棵树中最大子树最小的点,即树的重心。
1 //size表示整棵树的大小 2 void findrt(int u,int fa) { 3 sz[u]=1; f[u]=0; 4 for(int i=fro[u];i;i=nxt[i]) { 5 int v=to[i]; 6 if(vis[v]||v==fa) continue; 7 findrt(v,u); 8 sz[u]+=sz[v]; f[u]=max(f[u],sz[v]); 9 } 10 f[u]=max(f[u],size-sz[u]); 11 if(f[u]<f[root]) root=u; 12 }
本题完整代码:
1 #include<cstdio> 2 #include<cstring> 3 #include<algorithm> 4 using namespace std; 5 const int N=1e4+5; 6 int n,k,ans,root,size,tot,d[N],dep[N],sz[N],f[N]; 7 int cnt,fro[N],to[N<<1],w[N<<1],nxt[N<<1]; 8 bool vis[N]; 9 10 inline int read() { 11 int x=0; char c=getchar(); 12 while(c<'0'||c>'9') c=getchar(); 13 while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-'0',c=getchar(); 14 return x; 15 } 16 void add(int x,int y,int z) { 17 to[++cnt]=y,w[cnt]=z,nxt[cnt]=fro[x]; fro[x]=cnt; 18 } 19 20 void findrt(int u,int fa) { 21 sz[u]=1; f[u]=0; 22 for(int i=fro[u];i;i=nxt[i]) { 23 int v=to[i]; 24 if(vis[v]||v==fa) continue; 25 findrt(v,u); 26 sz[u]+=sz[v],f[u]=max(f[u],sz[v]); 27 } 28 f[u]=max(f[u],size-sz[u]); 29 if(f[u]<f[root]) root=u; 30 } 31 void getdeep(int u,int fa) { 32 d[++tot]=dep[u]; 33 for(int i=fro[u];i;i=nxt[i]) { 34 int v=to[i]; 35 if(vis[v]||v==fa) continue; 36 dep[v]=dep[u]+w[i]; 37 getdeep(v,u); 38 } 39 } 40 int clac(int u) { 41 tot=0; getdeep(u,0); 42 sort(d+1,d+tot+1); 43 int sum=0,l=1,r=tot; 44 while(l<r) { 45 if(d[l]+d[r]<=k) sum+=r-l,l++; 46 else r--; 47 } 48 return sum; 49 } 50 void solve(int u) { 51 vis[u]=1; 52 dep[u]=0; ans+=clac(u); 53 for(int i=fro[u];i;i=nxt[i]) { 54 int v=to[i]; 55 if(vis[v]) continue; 56 dep[v]=w[i]; ans-=clac(v); 57 root=0; size=sz[v]; 58 findrt(v,0); solve(root); 59 } 60 } 61 62 int main() { 63 while(scanf("%d%d",&n,&k)&&(n||k)) { 64 cnt=0,ans=0; 65 memset(fro,0,sizeof(fro)); 66 memset(vis,0,sizeof(vis)); 67 for(int i=1;i<n;i++) { 68 int x=read(),y=read(),z=read(); 69 add(x,y,z); add(y,x,z); 70 } 71 root=0; size=f[0]=n; 72 findrt(1,0); solve(root); 73 printf("%d ",ans); 74 } 75 return 0; 76 }
其它例题
把路径长度对$3$取模后答案为$0,1,2$的路径条数分别保存为$t[0],t[1],t[2]$。
答案:$2 imes t[1] imes t[2]+t[0] imes t[0]$。
1 #include<bits/stdc++.h> 2 using namespace std; 3 const int N=2e4+5; 4 int n,ans,root,size,t[3],dep[N],sz[N],f[N]; 5 int cnt,to[N<<1],w[N<<1],nxt[N<<1],fro[N]; 6 bool vis[N]; 7 8 inline int read() { 9 int x=0; char c=getchar(); 10 while(c<'0'||c>'9') c=getchar(); 11 while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-'0',c=getchar(); 12 return x; 13 } 14 void add(int x,int y,int z) { 15 to[++cnt]=y,w[cnt]=z,nxt[cnt]=fro[x]; fro[x]=cnt; 16 } 17 int gcd(int a,int b) {return b?gcd(b,a%b):a;} 18 19 void findrt(int u,int fa) { 20 sz[u]=1; f[u]=0; 21 for(int i=fro[u];i;i=nxt[i]) { 22 int v=to[i]; 23 if(vis[v]||v==fa) continue; 24 findrt(v,u); 25 sz[u]+=sz[v]; f[u]=max(f[u],sz[v]); 26 } 27 f[u]=max(f[u],size-sz[u]); 28 if(f[u]<f[root]) root=u; 29 } 30 void query(int u,int fa) { 31 t[dep[u]]++; 32 for(int i=fro[u];i;i=nxt[i]) { 33 int v=to[i]; 34 if(vis[v]||v==fa) continue; 35 dep[v]=(dep[u]+w[i])%3; 36 query(v,u); 37 } 38 } 39 int calc(int u,int d0) { 40 dep[u]=d0; 41 t[0]=t[1]=t[2]=0; 42 query(u,0); 43 return t[0]*t[0]+2*t[1]*t[2]; 44 } 45 void solve(int u) { 46 ans+=calc(u,0); vis[u]=1; 47 for(int i=fro[u];i;i=nxt[i]) { 48 int v=to[i]; 49 if(vis[v]) continue; 50 ans-=calc(v,w[i]); 51 root=0,size=sz[v]; 52 findrt(v,0); solve(root); 53 } 54 } 55 56 int main() { 57 n=read(); 58 for(int i=1;i<n;i++) { 59 int x=read(),y=read(),z=read()%3; 60 add(x,y,z),add(y,x,z); 61 } 62 size=f[0]=n; 63 findrt(1,0); solve(root); 64 int t=gcd(ans,n*n); 65 printf("%d/%d ",ans/t,n*n/t); 66 }
如有错误、疑问请联系作者(见公告)!感谢。