点分治是一种常用于处理树上点对关系的分治算法。
一、算法介绍
提到点分治,我们先来看一道例题:洛谷P3806 【模板】点分治1
题意:多组询问,边有边权,询问树上是否存在距离为$ k $的点对。$ n leq 10^4, k leq 10^7 $
我们显然有一种暴力算法:对于每个询问,枚举每个点对判断距离是否等于给定的$ k $,复杂度$ O(mn^2) $,但这样的复杂显然太高了,我们需要更快的算法。
我们发现如果钦定了树根,那么$ dist(x,y)=dep[x]+dep[y]-2 imes dep[lca] $,于是我们可以尝试枚举$ lca $,然后搜索子树中的每个节点,对于遍历到的当前结点$ now $,寻找是否存在一个已遍历节点$ x $使$ dep[x]=k+2 imes dep[lca]-dep[y] $,这个我们可以用一个桶存下已遍历结点到$ lca $的距离。然而我们发现状况还是没什么变化,复杂度依然是$ O(mn^2) $的。
然而,我们发现在计算$ lca $的贡献的时候,我们相当于一次统计了经过$ lca $的所有路径,于是我们接下来只需对没有经过$ lca$的路径,也就是$ lca $的每个子树分别统计,这就是所谓的“分治”,将问题分解成几个子问题分别处理。
但是,这和前面的算法有什么区别?这样的算法在某些数据下还是不够优秀,比如当树是一条链的情况下,每次$ O(n) $处理完当前节点后,然后往子节点走,这样时间复杂度还是会被卡到$ O(n^2) $。虽然如此,我们发现造成时间复杂度退化的原因是,我们在处理每个子树时钦定的根节点不够优秀。这使我们就会想到无根树上有一个性质极其优秀的点:重心。重心满足以它为根,它的每个子树的大小都不会超过$ frac{n}{2} $的性质,如果我们每次开始处理一块子树,都以这块子树的重心为根开始处理,每次能使问题的规模下降以半,最多分治$ O(log{n}) $层。在上面的例题中,总复杂度为$ O(mnlog{n}) $
现在我们总结一下点分治的基本思路:
1、先计算出当前统计的这一块树的答案。(注意:计算答案时需保证点对是在$ lca $不同的子树中,这样才能保证路径经过$ lca $,如例题中,我们处理完一个子树后才把该子树的信息加入桶中)
2、找出这块树的重心。
3、把重心删除,递归处理该树断开成的几棵子树。
核心代码:
void divide(int now) { solve(now);//计算以now为根的这棵树的答案 vis[now]=1;//删除节点now for(int i=son of now){ int nxt=getroot(i);//计算儿子i所在的这棵子树的重心 divide(nxt);//递归分治处理 } }
以下是例题的完整代码:
#include<cstdio> #include<cstdlib> #include<cstring> #include<cmath> #include<ctime> #include<algorithm> #define ll long long #define maxn 10010 inline ll read() { ll x=0; char c=getchar(),f=1; for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1; for(;'0'<=c&&c<='9';c=getchar())x=x*10+c-'0'; return x*f; } inline void write(ll x) { char buf[20],len; len=0; if(x<0)putchar('-'),x=-x; for(;x;x/=10)buf[len++]=x%10+'0'; if(!len)putchar('0'); else while(len)putchar(buf[--len]); } inline void writesp(ll x){write(x); putchar(' ');} inline void writeln(ll x){write(x); putchar(' ');} struct edge{ int to,nxt,d; }e[2*maxn]; int fir[maxn],dist[maxn],size[maxn],vis[maxn]; int id[maxn]; int mark[10000010]; int q[110],ok[110]; int n,m,tot; void add_edge(int x,int y,int z){e[tot].to=y; e[tot].d=z; e[tot].nxt=fir[x]; fir[x]=tot++;} void search(int now,int fa) { id[++tot]=now; size[now]=1; for(int i=fir[now];~i;i=e[i].nxt) if(e[i].to!=fa&&!vis[e[i].to]){ dist[e[i].to]=dist[now]+e[i].d; search(e[i].to,now); size[now]+=size[e[i].to]; } } void solve(int now) { tot=1; id[1]=now; dist[now]=0; mark[0]=1; int last=1; for(int i=fir[now];~i;i=e[i].nxt) if(!vis[e[i].to]){ dist[e[i].to]=e[i].d; search(e[i].to,now); for(int j=1;j<=m;j++){ if(ok[j])continue; for(int k=last+1;k<=tot;k++) if(dist[id[k]]<=q[j])ok[j]|=mark[q[j]-dist[id[k]]]; } for(int j=last+1;j<=tot;j++) if(dist[id[j]]<=10000000)mark[dist[id[j]]]=1; last=tot; } for(int i=1;i<=tot;i++) if(dist[id[i]]<=10000000)mark[dist[id[i]]]=0; } int getroot(int now,int fa,int S) { int mx=0; size[now]=1; for(int i=fir[now];~i;i=e[i].nxt) if(e[i].to!=fa&&!vis[e[i].to]){ int t=getroot(e[i].to,now,S); if(t)return t; size[now]+=size[e[i].to]; if(size[e[i].to]>mx)mx=size[e[i].to]; } if(S-size[now]>mx)mx=S-size[now]; if(mx*2<=S)return now; else return 0; } void divide(int now) { solve(now); vis[now]=1; for(int i=fir[now];~i;i=e[i].nxt) if(!vis[e[i].to]){ int rt=getroot(e[i].to,now,size[e[i].to]); divide(rt); } } int main() { n=read(); m=read(); memset(fir,255,sizeof(fir)); tot=0; for(int i=1;i<n;i++){ int x=read(),y=read(),z=read(); add_edge(x,y,z); add_edge(y,x,z); } for(int i=1;i<=m;i++) q[i]=read(); int rt=getroot(1,-1,n); divide(rt); for(int i=1;i<=m;i++) puts(ok[i]?"AYE":"NAY"); return 0; }
二、练习
题意:给一棵有边权的树求一条长度最短的距离为$ K $的路径。
显然是一道点分治裸题,在统计每个$ lca $的答案时用一个桶记录到$ lca $距离一定的结点的最小深度,然后其他做法与上题基本相同。
代码:
// luogu-judger-enable-o2 #include<cstdio> #include<cstring> #include<cmath> #include<cstdlib> #include<ctime> #include<algorithm> #define ll long long #define inf 0x3f3f3f3f #define maxn 200010 inline ll read() { ll x=0; char c=getchar(),f=1; for(;c<'0'||'9'<c;c=getchar())if(c=='-')f=-1; for(;'0'<=c&&c<='9';c=getchar())x=x*10+c-'0'; return x*f; } inline void write(ll x) { static char buf[20],len; len=0; if(x<0)x=-x,putchar('-'); for(;x;x/=10)buf[len++]=x%10+'0'; if(!len)putchar('0'); else while(len)putchar(buf[--len]); } inline void writesp(ll x){write(x); putchar(' ');} inline void writeln(ll x){write(x); putchar(' ');} struct edge{ int to,nxt,d; }e[2*maxn]; int fir[maxn],size[maxn],dep[maxn],vis[maxn]; ll dist[maxn]; int id[maxn],mn[1000010]; int n,m,tot,ans; void add_edge(int x,int y,int z){e[tot].to=y; e[tot].d=z; e[tot].nxt=fir[x]; fir[x]=tot++;} void search(int now,int fa) { id[++tot]=now; size[now]=1; for(int i=fir[now];~i;i=e[i].nxt) if(e[i].to!=fa&&!vis[e[i].to]){ dist[e[i].to]=dist[now]+e[i].d; dep[e[i].to]=dep[now]+1; search(e[i].to,now); size[now]+=size[e[i].to]; } } void solve(int now) { mn[0]=0; tot=0; int last=1; for(int i=fir[now];~i;i=e[i].nxt) if(!vis[e[i].to]){ dist[e[i].to]=e[i].d; dep[e[i].to]=1; search(e[i].to,now); for(int j=last;j<=tot;j++) if(dist[id[j]]<=m)ans=std::min(ans,dep[id[j]]+mn[m-dist[id[j]]]); for(int j=last;j<=tot;j++) if(dist[id[j]]<=m)mn[dist[id[j]]]=std::min(mn[dist[id[j]]],dep[id[j]]); last=tot+1; } for(int i=1;i<=tot;i++) if(dist[id[i]]<=m)mn[dist[id[i]]]=inf; mn[0]=inf; } int getroot(int now,int fa,int S) { // printf("%d %d %d **** ",now,fa,S); size[now]=1; int mx=0; for(int i=fir[now];~i;i=e[i].nxt) if(e[i].to!=fa&&!vis[e[i].to]){ int tmp=getroot(e[i].to,now,S); if(~tmp)return tmp; size[now]+=size[e[i].to]; if(size[e[i].to]>mx)mx=size[e[i].to]; } if(S-size[now]>mx)mx=S-size[now]; if(mx<<1<=S)return now; else return -1; } void divide(int now) { // writeln(now); // system("pause"); solve(now); vis[now]=1; for(int i=fir[now];~i;i=e[i].nxt) if(!vis[e[i].to]){ int nxt=getroot(e[i].to,now,size[e[i].to]); divide(nxt); } } int main() { n=read(); m=read(); memset(fir,255,sizeof(fir)); tot=0; for(int i=1;i<n;i++){ int x=read(),y=read(),z=read(); add_edge(x,y,z); add_edge(y,x,z); } memset(mn,0x3f,sizeof(mn)); ans=inf; int init=getroot(0,-1,n); divide(init); writeln(ans!=inf?ans:-1); return 0; }