点分治
点分治是树上分治的一种(树上分治还有边分治),常用于解决和树上路径有关的问题。
因为树上路径有一条性质:树上的任何路径,要么经过根节点$rt$要么就全部在$rt$的一颗子树上。
正确性显而易见:树上两点的路径是唯一的,如果两点在$rt$的同一子树上,则路径完全在一颗子树上,如果在$rt$的不同子树,则必然经过$rt$。
有了这条性质,我们就可以对树上的路径进行分治:先将经过$rt$的路径处理完,此时$rt$的各个子树上的路径就互不影响了,故可以递归分治。
但怎么使得分治均匀呢?如果随意选择根节点,则如果树退化成链,则递归层数为$N$,且每次操作的节点数目都会非常多。
所以此时,我们选择树的重心,这样可以使每一次分治后剩下的最大子树的大小下降最快(重心的最大子树最小),每一次分治我们都找重心,操作本身复杂度为$O(NlogN)$,可以使分治底层执行次数降到$O(NlogN)$(主定理)。通常情况下,节点的子树超过$2$棵,则复杂度往往会低于$O(NlogN)$
void getrt(int x,int fa){ siz[x]=1; maxs[x]=0; for(int i=head[x];i;i=nxt[i]){ if(fa==to[i]&&vis[to[i]]) continue; getrt(to[i],x); siz[x]+=siz[to[i]]; maxs[x]=max(maxs[x],siz[to[i]]); } maxs[x]=max(maxs[x],sum-siz[x]); if(maxs[x]<maxs[rt]) rt=x; }
这有什么用呢?
举个例子,如果我们要统计一棵树内各个长度的路径的个数,就可以用点分治来做。
首先,我们将树的重心设为根,求出树上各点到根的距离($O(N)$),并统计每个长度的路径的数量($O(N^2)$),然后枚举每一个子树,递归执行同样的操作,并在同一个数组上统计数量
算法的核心在于分治:
就着代码讲:
void Divid(int x) { ans+=solve(x,0); ①统计经过RT的所有路径(不一定合法,下文会重点讲) vis[x] = 1; for (int i = head[x];i;i = edges[i].net) 枚举所有子树 { edge v = edges[i]; if(vis[v.to]) continue; 防止遍历到上层 ans-=solve(v.to,edges[i].cost); 减掉①中统计的不合法路径* S = size[v.to]; root = 0; find(v.to,x); 找到新根 Divid(root); 按子树分治 } }
*上面代码中提到求出的“经过RT的路径”不一定合法,要减掉一部分,为什么呢?在统计经过rt的路径时,我们将所有点都两两匹配了,此时同一棵子树中的点当然不能配对。
如图,我们将$RT ightarrow B$和$RT ightarrow C$的路径合并了,其中$RT ightarrow A$部分的路径重复了
怎么去掉重复的部分呢?见代码,我们【将“经过A点的路径”(不一定合法)加上$RT ightarrow A$的长度】这一部分产生的贡献去掉
如图,过A点的路径为$B ightarrow A ightarrow C$,统计结果为$A ightarrow B+A ightarrow C+2*(RT ightarrow A)$(因为所有长度都加了$RT ightarrow A$)
这样,我们就求出了所有经过RT的路径的贡献
接下来,就是逐步细化实现了。对于大部分点分治的题目,上面分治的代码都是差不多的,题目的差异在于solve函数。
下面,我以洛谷P3806 【模板】点分治1为例介绍一下实现的过程 题解
此题的题意很简单,就是询问树上长度为k的路径是否存在。看到询问路径,就很容易想到点分治,故我们可以套上面的模板。
这里介绍一下solve函数的实现
首先,我们要求出从当前根节点到子数中所有点的距离,然后将这些距离组合配对,将所有的和统计到答案中即可
void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N^2)*/ tp=0; dis[x]=len; get_dis(x,0,len); for(int i=1;i<=tp;i++) for(int j=1;j<=tp;j++) if(i!=j) ans[st[i]+st[j]]+=w; }
这是求距离的代码
void get_dis(int x,int fa,int len){ if(len<=1e7) st[++tp]=len; for(int i=head[x];i;i=nxt[i]){ if(to[i]==fa||vis[to[i]]) continue; dis[to[i]]=len+val[i]; get_dis(to[i],x,len+val[i]); } }
于是,我们能得到以下代码,可以AC(是因为数据水)
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long LL; 4 const int MAXK=2e7,MAXN=1e4+7,MAXM=2e4+7; 5 inline void Max(int &x,int y){ 6 x=x>y?x:y; 7 } 8 int sz,head[MAXN],to[MAXM],nxt[MAXM],val[MAXM]; 9 inline void add(int x,int y,int z){ 10 nxt[++sz]=head[x]; head[x]=sz; to[sz]=y; val[sz]=z; 11 nxt[++sz]=head[y]; head[y]=sz; to[sz]=x; val[sz]=z; 12 } 13 int rt,siz[MAXN],maxson[MAXN],vis[MAXN],S; 14 void find(int x/*cur vertex*/,int fa/*father*/){/*find root*/ 15 siz[x]=1; 16 maxson[x]=0; 17 for(int i=head[x];i;i=nxt[i]){ 18 if(to[i]==fa||vis[to[i]]) 19 continue; 20 find(to[i],x); 21 siz[x]+=siz[to[i]]; 22 Max(maxson[x],siz[to[i]]); 23 } 24 Max(maxson[x],S-siz[x]); 25 if(maxson[x]<maxson[rt]) 26 rt=x; 27 } 28 int dis[MAXN],st[MAXN],tp; 29 void get_dis(int x,int fa,int len){ 30 if(len<=1e7) 31 st[++tp]=len; 32 for(int i=head[x];i;i=nxt[i]){ 33 if(to[i]==fa||vis[to[i]]) 34 continue; 35 dis[to[i]]=len+val[i]; 36 get_dis(to[i],x,len+val[i]); 37 } 38 } 39 int ans[MAXK]; 40 void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N^2)*/ 41 tp=0; 42 dis[x]=len; 43 get_dis(x,0,len); 44 for(int i=1;i<=tp;i++) 45 for(int j=1;j<=tp;j++) 46 if(i!=j) 47 ans[st[i]+st[j]]+=w; 48 } 49 int N,Q,K; 50 void divide(int x){ 51 solve(x,0,1); 52 vis[x]=1; 53 for(int i=head[x];i;i=nxt[i]){ 54 if(vis[to[i]]) 55 continue; 56 solve(to[i],val[i],-1); 57 S=siz[x]; 58 rt=0; 59 maxson[0]=N; 60 find(to[i],x); 61 divide(rt); 62 } 63 } 64 int main(){ 65 scanf("%d%d",&N,&Q); 66 for(int i=1;i<N;i++){ 67 int ii,jj,kk; 68 scanf("%d%d%d",&ii,&jj,&kk); 69 add(ii,jj,kk); 70 } 71 S=N; 72 maxson[0]=N; 73 rt=0; 74 find(1,0); 75 divide(rt); 76 while(Q--){ 77 scanf("%d",&K); 78 puts(ans[K]?"AYE":"NAY"); 79 } 80 return 0; 81 }
分析代码可以发现,实际上代码的时间复杂度为$Theta(N^2 log N)$,在较强的数据中是会TLE的,于是我们要优化
我们发现我们对于所有可能的询问,都统计了答案,这其实是一种冗余。题目中的m非常小,我们其实可以根据询问来统计答案,效率会提高两个数量级
于是,我们很容易想到将询问离线,然后每次在表内统计一份结果,并且枚举所有的询问,在表内查询之前是否得到过$答案-当前结果$的值
新的solve函数
void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N*M)*/ ++timeclock; tp=0; dis[x]=len; get_dis(x,0,len); for(int i=1;i<=tp;i++) for(int j=1;j<=Q;j++){ int ii=qry[j]-st[i]; if(ii<0||date[ii]!=timeclock||(b[ii]==1&&ii==st[i])) continue; ans[j]+=w; } }
最后的代码也就很简单了,用时为原来的几十分之一
1 #include<bits/stdc++.h> 2 using namespace std; 3 typedef long long LL; 4 const int MAXK=2e7,MAXN=1e4+7,MAXM=2e4+7,MAXQ=1e2+7; 5 inline void Max(int &x,int y){ 6 x=x>y?x:y; 7 } 8 int sz,head[MAXN],to[MAXM],nxt[MAXM],val[MAXM]; 9 inline void add(int x,int y,int z){ 10 nxt[++sz]=head[x]; head[x]=sz; to[sz]=y; val[sz]=z; 11 nxt[++sz]=head[y]; head[y]=sz; to[sz]=x; val[sz]=z; 12 } 13 int rt,siz[MAXN],maxson[MAXN],vis[MAXN],S; 14 void find(int x/*cur vertex*/,int fa/*father*/){/*find root*/ 15 siz[x]=1; 16 maxson[x]=0; 17 for(int i=head[x];i;i=nxt[i]){ 18 if(to[i]==fa||vis[to[i]]) 19 continue; 20 find(to[i],x); 21 siz[x]+=siz[to[i]]; 22 Max(maxson[x],siz[to[i]]); 23 } 24 Max(maxson[x],S-siz[x]); 25 if(maxson[x]<maxson[rt]) 26 rt=x; 27 } 28 int dis[MAXN],st[MAXN],tp; 29 int qry[MAXQ],ans[MAXQ],date[MAXK],b[MAXK],timeclock; 30 int N,Q,K; 31 void get_dis(int x,int fa,int len){ 32 if(len<=1e7){ 33 st[++tp]=len; 34 if(date[len]==timeclock) 35 b[len]++; 36 else{ 37 b[len]=1; 38 date[len]=timeclock; 39 } 40 } 41 for(int i=head[x];i;i=nxt[i]){ 42 if(to[i]==fa||vis[to[i]]) 43 continue; 44 dis[to[i]]=len+val[i]; 45 get_dis(to[i],x,len+val[i]); 46 } 47 } 48 void solve(int x,int len/*start dis*/,int w/*weight*/){/*O(N*M)*/ 49 ++timeclock; 50 tp=0; 51 dis[x]=len; 52 get_dis(x,0,len); 53 for(int i=1;i<=tp;i++) 54 for(int j=1;j<=Q;j++){ 55 int ii=qry[j]-st[i]; 56 if(ii<0||date[ii]!=timeclock||(b[ii]==1&&ii==st[i])) 57 continue; 58 ans[j]+=w; 59 } 60 } 61 void divide(int x){ 62 solve(x,0,1); 63 vis[x]=1; 64 for(int i=head[x];i;i=nxt[i]){ 65 if(vis[to[i]]) 66 continue; 67 solve(to[i],val[i],-1); 68 S=siz[x]; 69 rt=0; 70 maxson[0]=N; 71 find(to[i],x); 72 divide(rt); 73 } 74 } 75 int main(){ 76 scanf("%d%d",&N,&Q); 77 for(int i=1;i<N;i++){ 78 int ii,jj,kk; 79 scanf("%d%d%d",&ii,&jj,&kk); 80 add(ii,jj,kk); 81 } 82 for(int i=1;i<=Q;i++){ 83 scanf("%d",qry+i); 84 } 85 S=N; 86 maxson[0]=N; 87 rt=0; 88 find(1,0); 89 divide(rt); 90 for(int i=1;i<=Q;i++) 91 puts(ans[i]?"AYE":"NAY"); 92 return 0; 93 }