开始学树的点分治了,写一篇博客记录。
推荐博客:入门 https://blog.csdn.net/a_forever_dream/article/details/81778649
点分治+动态点分治:https://www.cnblogs.com/bztMinamoto/p/9489473.html
题目:https://www.cnblogs.com/zhenglw/p/10658210.html
树的点分治,顾名思义,就是在树上基于点的分治来解决树的问题。具体来说就是选择树的重心把树进行拆分为几棵子树分治求解,再把子树的结果合并得到这棵树的结果。至于为什么要选择树的重心来进行拆分呢?其实可以证明以树的重心拆分的分治时间复杂度为O(logn)的。
说白了,树的点分治只是提供了一种树的分治方法,但是具体分治后要求什么?写什么?还得看题目问的是什么。
当然虽然分治后要求千变万化,但是还是有一些经典题目来提供一些套路让我们学习的。
模板题:洛谷P3806 点分治1
注意到多组询问且k不等,那么得先预处理树上所有两点间距离放到桶子里,再O(1)处理询问。
那么怎么通过点分治求出树上所有两点间距离呢?这里是通过容斥原理,先求出所有点到根节点距离然后通过双重循环得到所有经过根节点的路径,但是我们注意到如果两个点都位于根节点的同一棵子树下的话,这两个点的路径是不应该经过根节点的,那么我们就得想办法去掉这种情况。做法是:对所有子树也做一次上诉操作但是要设置一个初始距离为len[i],这样的话子树内两两结点的距离就是上诉的不经过根节点的不合理情况。减去即可。
#pragma comment(linker,"/STACK:102400000,102400000") #include<bits/stdc++.h> using namespace std; const int N=1e5+10; const int INF=0x3f3f3f3f; int n,m,Min,rt,Size,num; int sz[N],mson[N],dis[N],vis[N]; int sum[10000000+10]; int cnt=1,head[N],nxt[N<<1],to[N<<1],len[N<<1]; void add_edge(int x,int y,int z) { nxt[++cnt]=head[x]; to[cnt]=y; len[cnt]=z; head[x]=cnt; } void getroot(int x,int fa) { sz[x]=1; mson[x]=0; //sz[x]表示x子树大小 mson[x]表示x的最大子树大小 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getroot(y,x); sz[x]+=sz[y]; mson[x]=max(mson[x],sz[y]); } mson[x]=max(mson[x],Size-sz[x]); //除去x子树外的其他点也组成一棵树 if (Min>mson[x]) Min=mson[x],rt=x; //记录树的重心 } void getdis(int x,int fa,int d) { dis[++num]=d; //记录x子树下所有点到x的距离 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getdis(y,x,d+len[i]); } } //点分治的步骤不太会改变,solve函数就是点分治时要处理的问题 void solve(int x,int d,int opt) { //求距离x点初始距离为d 的所有距离放到sum桶子里 num=0; getdis(x,0,d); if (opt==1) { for (int i=1;i<=num;i++) for (int j=i+1;j<=num;j++) sum[dis[i]+dis[j]]++; } else { //容斥原理要减去 for (int i=1;i<=num;i++) for (int j=i+1;j<=num;j++) sum[dis[i]+dis[j]]--; } } void divide(int x) { //点分治 vis[x]=1; //vis[x]代表该点已经分治 solve(x,0,1); //加上整颗树的 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y]) continue; solve(y,len[i],0); //减去子树的,这里注意是分治前的子树 Min=INF; rt=0; Size=sz[y]; getroot(y,x); //初始化找y子树的重心 divide(rt); //继续分治 } } int main() { cin>>n>>m; for (int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); add_edge(x,y,z); add_edge(y,x,z); } Min=INF; rt=0; Size=n; getroot(1,0); divide(rt); //分治求出所有树上两点距离 塞到sum桶子里 for (int i=1;i<=m;i++) { int x; scanf("%d",&x); if (sum[x]) puts("AYE"); else puts("NAY"); } return 0; }
同样是模板题:POJ1741 Tree
这道题和上面的差不多,不过只有一组询问且要求的是长度<=K的点对,那么我们就不用把所有距离记录下来(主要是时间也慢)。还是容斥原理,我们通过排序后扫描(这个可能不好理解建议看代码)的做法得到根的答案,同样方法得到各个子树的答案并减去。

#pragma comment(linker,"/STACK:102400000,102400000") #include<iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int N=1e5+10; const int INF=0x3f3f3f3f; int n,k,Min,rt,Size,num,ans; int sz[N],mson[N],dis[N],vis[N]; int cnt=1,head[N],nxt[N<<1],to[N<<1],len[N<<1]; void add_edge(int x,int y,int z) { nxt[++cnt]=head[x]; to[cnt]=y; len[cnt]=z; head[x]=cnt; } void getroot(int x,int fa) { sz[x]=1; mson[x]=0; //sz[x]表示x子树大小 mson[x]表示x的最大子树大小 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getroot(y,x); sz[x]+=sz[y]; mson[x]=max(mson[x],sz[y]); } mson[x]=max(mson[x],Size-sz[x]); //除去x子树外的其他点也组成一棵树 if (Min>mson[x]) Min=mson[x],rt=x; //记录树的重心 } void getdis(int x,int fa,int d) { dis[++num]=d; //记录x子树下所有点到x的距离 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getdis(y,x,d+len[i]); } } //点分治的步骤不太会改变,solve函数就是点分治时要处理的问题 void solve(int x,int d,int opt) { //求距离x点初始距离为d 的所有距离放到sum桶子里 num=0; getdis(x,0,d); sort(dis+1,dis+num+1); int tmp=0; for (int l=1,r=num;l<r;) //对于每一个l求出多少个l右边的数r满足dis[l]+dis[r]<=k if (dis[l]+dis[r]<=k) tmp+=r-l,l++; else r--; if (opt==1) ans+=tmp; else ans-=tmp; } void divide(int x) { //点分治 vis[x]=1; //vis[x]代表该点已经分治 solve(x,0,1); //加上整颗树的 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y]) continue; solve(y,len[i],0); //减去子树的,这里注意是分治前的子树 Min=INF; rt=0; Size=sz[y]; getroot(y,x); //初始化找y子树的重心 divide(rt); //继续分治 } } int main() { while (cin>>n>>k && n) { cnt=1; memset(head,0,sizeof(head)); for (int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); add_edge(x,y,z); add_edge(y,x,z); } for (int i=1;i<=n;i++) vis[i]=0; ans=0; Min=INF; rt=0; Size=n; getroot(1,0); divide(rt); cout<<ans<<endl; } return 0; }
洛谷P3634 [国家集训队]聪聪可可
这道题也是求任两点间距离,不过要求的是距离是3的倍数的点对。那么想到对于dis数组对3取模,那么在该棵树的距离为3倍数的数量为cnt[0]*cnt[0]+2*cnt[1]*cnt[2],容斥计算即可。这里有一个小细节:为什么这样计算能包括像(x,x)这样的点对呢?因为这里的dis数组是记录以x为根的树下所有点(包括x点)到x的距离,所以会有dis[x]=0这个值,这个值会被算在cnt[0]里面。

#pragma comment(linker,"/STACK:102400000,102400000") #include<iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int N=1e5+10; const int INF=0x3f3f3f3f; int n,k,Min,rt,Size,num,ans; int sz[N],mson[N],dis[N],vis[N],c[5]; int gcd(int a,int b) { return b==0 ? a : gcd(b,a%b); } int cnt=1,head[N],nxt[N<<1],to[N<<1],len[N<<1]; void add_edge(int x,int y,int z) { nxt[++cnt]=head[x]; to[cnt]=y; len[cnt]=z; head[x]=cnt; } void getroot(int x,int fa) { sz[x]=1; mson[x]=0; //sz[x]表示x子树大小 mson[x]表示x的最大子树大小 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getroot(y,x); sz[x]+=sz[y]; mson[x]=max(mson[x],sz[y]); } mson[x]=max(mson[x],Size-sz[x]); //除去x子树外的其他点也组成一棵树 if (Min>mson[x]) Min=mson[x],rt=x; //记录树的重心 } void getdis(int x,int fa,int d) { dis[++num]=d; //记录以x为根的树下所有点(包括x点)到x的距离 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getdis(y,x,d+len[i]); } } //点分治的步骤不太会改变,solve函数就是点分治时要处理的问题 void solve(int x,int d,int opt) { //求距离x点初始距离为d 的所有距离放到sum桶子里 num=0; getdis(x,0,d); int tmp=0; c[0]=c[1]=c[2]=0; for (int i=1;i<=num;i++) c[dis[i]%3]++; tmp=c[0]*c[0]+2*c[1]*c[2]; if (opt==1) ans+=tmp; else ans-=tmp; } void divide(int x) { //点分治 vis[x]=1; //vis[x]代表该点已经分治 solve(x,0,1); //加上整颗树的 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y]) continue; solve(y,len[i],0); //减去子树的,这里注意是分治前的子树 Min=INF; rt=0; Size=sz[y]; getroot(y,x); //初始化找y子树的重心 divide(rt); //继续分治 } } int main() { cin>>n; for (int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); add_edge(x,y,z); add_edge(y,x,z); } ans=0; Min=INF; rt=0; Size=n; getroot(1,0); divide(rt); int g=gcd(ans,n*n); cout<<ans/g<<"/"<<n*n/g<<endl; return 0; }
Codeforces-161D Distance in Tree
懒得想新做法,在POJ1741的基础上直接求出ans(dis<=k)和ans(dis<=k-1)然后相减即可。

#pragma comment(linker,"/STACK:102400000,102400000") #include<iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int N=1e5+10; const int INF=0x3f3f3f3f; int n,k,Min,rt,Size,num,ans; int sz[N],mson[N],dis[N],vis[N]; int cnt=1,head[N],nxt[N<<1],to[N<<1],len[N<<1]; void add_edge(int x,int y,int z) { nxt[++cnt]=head[x]; to[cnt]=y; len[cnt]=z; head[x]=cnt; } void getroot(int x,int fa) { sz[x]=1; mson[x]=0; //sz[x]表示x子树大小 mson[x]表示x的最大子树大小 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getroot(y,x); sz[x]+=sz[y]; mson[x]=max(mson[x],sz[y]); } mson[x]=max(mson[x],Size-sz[x]); //除去x子树外的其他点也组成一棵树 if (Min>mson[x]) Min=mson[x],rt=x; //记录树的重心 } void getdis(int x,int fa,int d) { dis[++num]=d; //记录x子树下所有点到x的距离 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getdis(y,x,d+len[i]); } } //点分治的步骤不太会改变,solve函数就是点分治时要处理的问题 void solve(int x,int d,int opt) { //求距离x点初始距离为d 的所有距离放到sum桶子里 num=0; getdis(x,0,d); sort(dis+1,dis+num+1); int tmp=0; for (int l=1,r=num;l<r;) //对于每一个l求出多少个l右边的数r满足dis[l]+dis[r]<=k if (dis[l]+dis[r]<=k) tmp+=r-l,l++; else r--; if (opt==1) ans+=tmp; else ans-=tmp; } void divide(int x) { //点分治 vis[x]=1; //vis[x]代表该点已经分治 solve(x,0,1); //加上整颗树的 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y]) continue; solve(y,len[i],0); //减去子树的,这里注意是分治前的子树 Min=INF; rt=0; Size=sz[y]; getroot(y,x); //初始化找y子树的重心 divide(rt); //继续分治 } } int main() { cin>>n>>k; for (int i=1;i<n;i++) { int x,y,z; scanf("%d%d",&x,&y); z=1; add_edge(x,y,z); add_edge(y,x,z); } for (int i=1;i<=n;i++) vis[i]=0; int Ans=0; ans=0; Min=INF; rt=0; Size=n; getroot(1,0); divide(rt); Ans+=ans; for (int i=1;i<=n;i++) vis[i]=0; k--; ans=0; Min=INF; rt=0; Size=n; getroot(1,0); divide(rt); Ans-=ans; cout<<Ans<<endl; return 0; }
洛谷P4149 [IOI2011]Race
这道题不能用上面几题的容斥原理先加后减,因为这里要求的不是满足条件的路径条数,而是要求满足条件的路径最短,这个并不满足可加性。所以我们可以换一种解决办法:这种解决办法类似于树形dp求树的直径,我们用个桶子c[i]代表距离根节点x距离为i的最少路径条数。那么我们先对新的子树求dis数组,然后把新的子树和旧的子树结果合并并尝试更新答案,更新答案后把新的子树结果合并到桶子里。这样就能实现两两子树都有机会组合成为答案。
这里我们思考一个问题:为什么上面几题不同这种方法呢?因为上诉几题都需要统计任两个点对之间的距离才能得到答案。而用桶子来统计答案会损失数量这个信息,而对于要求统计数量的题目这是致命的。而若不用桶子,那么用什么来保存上诉所说的 所有旧的子树结果 且不损失数量信息?如果暴力保存空间时间上都爆炸,所以还是用容斥原理更为简易可行。

#pragma comment(linker,"/STACK:102400000,102400000") #include<iostream> #include<cstdio> #include<algorithm> #include<cstring> using namespace std; const int N=2e5+10; const int INF=0x3f3f3f3f; int n,k,Min,rt,Size,num,ans=INF; int sz[N],mson[N],vis[N],c[1000010]; struct dat{ int d1,d2; }dis[N]; int cnt=1,head[N],nxt[N<<1],to[N<<1],len[N<<1]; void add_edge(int x,int y,int z) { nxt[++cnt]=head[x]; to[cnt]=y; len[cnt]=z; head[x]=cnt; } void getroot(int x,int fa) { sz[x]=1; mson[x]=0; //sz[x]表示x子树大小 mson[x]表示x的最大子树大小 for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getroot(y,x); sz[x]+=sz[y]; mson[x]=max(mson[x],sz[y]); } mson[x]=max(mson[x],Size-sz[x]); //除去x子树外的其他点也组成一棵树 if (Min>mson[x]) Min=mson[x],rt=x; //记录树的重心 } void getdis(int x,int fa,int d,int dep) { if (d>k) return; ++num; dis[num].d1=d; dis[num].d2=dep; for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y] || y==fa) continue; getdis(y,x,d+len[i],dep+1); } } //点分治的步骤不太会改变,solve函数就是点分治时要处理的问题 void solve(int x) { //求距离x点初始距离为d 的所有距离放到sum桶子里 num=0; for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y]) continue; int lst=num; getdis(y,x,len[i],1); for (int j=lst+1;j<=num;j++) ans=min(ans,dis[j].d2+c[k-dis[j].d1]); for (int j=lst+1;j<=num;j++) c[dis[j].d1]=min(c[dis[j].d1],dis[j].d2); } for (int i=1;i<=num;i++) c[dis[i].d1]=INF; c[0]=0; //还原回去 } void divide(int x) { //点分治 vis[x]=1; //vis[x]代表该点已经分治 solve(x); for (int i=head[x];i;i=nxt[i]) { int y=to[i]; if (vis[y]) continue; Min=INF; rt=0; Size=sz[y]; getroot(y,x); //初始化找y子树的重心 divide(rt); //继续分治 } } int main() { cin>>n>>k; for (int i=1;i<n;i++) { int x,y,z; scanf("%d%d%d",&x,&y,&z); x++; y++; add_edge(x,y,z); add_edge(y,x,z); } memset(c,0x3f,sizeof(c)); c[0]=0; Min=INF; rt=0; Size=n; getroot(1,0); divide(rt); if (ans>=INF) puts("-1"); else cout<<ans<<endl; return 0; }