好消息,今天我学会了点分治。所谓点分治,就是将树上的需要点统计的N^2的问题强行转化为NlogN。
本来事情是这样的:
1 inline void dfs(int x,int fa){ 2 sz[x]=1; 3 for(int i=beg[x];i;i=nex[i]){ 4 int t=to[i]; 5 ………… 6 //统计这个子树,复杂度是这个子树的大小 7 dfs(t,x); 8 sz[x]+=sz[t]; 9 } 10 }
一旦树退化成链,这样遍历就是N^2的复杂度。遍历的时间效率取决于子树的大小。
于是便想到了重心。
重心的定义:树上的一个点使其所有子树大小的最大值最小。
由平均值原理已知其最大的子数大小也小于等于N/2,这样便保证了稳定的logN。
具体实现如下:
inline void getrt(int x,int fa){//找重心 sz[x]=1;son[x]=0; //sz数组记录这个子树的大小,son数组记录其最大子树的大小。 for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t]||t==fa)continue; //vis数组用于判断这个点是否到过。 getrt(t,x); sz[x]+=sz[t]; if(sz[t]>son[x])son[x]=sz[t]; } if(son[x]<size-sz[x])son[x]=size-sz[x]; if(mx>son[x])mx=son[x],rt=x; } inline void solve(int x){ vis[x]=1; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t])continue; ………… //子树大小的复杂度来处理 mx=inf,rt=0,size=sz[t]; getrt(t,x); //找这个子树的重心 solve(rt); //处理这棵子树 } }
我们来看一个例题。
例题1:P2634 [国家集训队]聪聪可可
网址:https://www.luogu.com.cn/problem/P2634
对于每个点,统计过他的边的对三取模的个数。看代码。
#include<bits/stdc++.h> using namespace std; #define inf 1e9 const int maxn=20005; int beg[maxn],nex[maxn*2],to[maxn*2],w[maxn*2],e; inline void add(int x,int y,int z){ e++;nex[e]=beg[x]; beg[x]=e;to[e]=y;w[e]=z; } inline int read(){ int x=0,f=1; char c=getchar(); while(c>'9'||c<'0'){ if(c=='-')f=-1; c=getchar(); } while(c>='0'&&c<='9'){ x=(x<<1)+(x<<3)+c-'0'; c=getchar(); } return x*f; } int n,mx,rt,size,ans; int sz[maxn],son[maxn]; int vis[maxn],tmp[5],dp[maxn][5]; inline void getrt(int x,int fa){ sz[x]=1;son[x]=0; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(t==fa||vis[t])continue; getrt(t,x); sz[x]+=sz[t]; if(sz[t]>son[x])son[x]=sz[t]; } if(son[x]<size-sz[x])son[x]=size-sz[x]; if(son[x]<mx)mx=son[x],rt=x; } inline void calc(int x,int fa,int val){ tmp[val%3]++; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(t==fa||vis[t])continue; calc(t,x,w[i]+val); } } inline void divide(int x){ vis[x]=1; int tp[5]; tp[0]=tp[1]=tp[2]=0; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t])continue; tmp[0]=tmp[1]=tmp[2]=0; calc(t,x,w[i]); dp[x][0]=dp[x][0]+2*tmp[0]+2*tmp[0]*tp[0]+2*tmp[1]*tp[2]+2*tmp[2]*tp[1]; dp[x][1]=dp[x][1]+2*tmp[1]+2*tmp[0]*tp[1]+2*tmp[1]*tp[0]+2*tmp[2]*tp[2]; dp[x][2]=dp[x][2]+2*tmp[2]+2*tmp[0]*tp[2]+2*tmp[1]*tp[1]+2*tmp[2]*tp[0]; tp[0]+=tmp[0],tp[1]+=tmp[1],tp[2]+=tmp[2]; mx=inf,rt=0,size=sz[t]; getrt(t,x); divide(rt); } dp[x][0]++; ans+=dp[x][0]; //printf("%d %d ",x,dp[x][0]); } inline int gcd(int x,int y){ if(y==0)return x; return gcd(y,x%y); } int main(){ cin>>n; int x,y,z; for(int i=1;i<n;i++){ x=read(),y=read(),z=read(); add(x,y,z);add(y,x,z); } mx=inf,rt=0,size=n; getrt(1,0); divide(rt); int t=gcd(ans,n*n); printf("%d/%d ",ans/t,n*n/t); return 0; }
很板子的一道题哈!
例题2:P3806 【模板】点分治1
网址:https://www.luogu.com.cn/problem/P3806
虽然说是一个模板题,但其实并不是那么模板……
数据加强之后,N^2的时间复杂度是过不了的。注意到m只有一百,考虑先输入m个数,离线处理。
代码如下:
#include<bits/stdc++.h> using namespace std; #define inf 0x3f3f3f3f const int maxn=25000; inline int read(){ int x=0,f=1; char c=getchar(); while(c>'9'||c<'0'){ if(c=='-')f=-1; c=getchar(); } while(c>='0'&&c<='9'){ x=(x<<1)+(x<<3)+c-'0'; c=getchar(); } return x*f; } int beg[maxn],nex[maxn],to[maxn],w[maxn],e; void add(int x,int y,int z){ e++;nex[e]=beg[x]; beg[x]=e;to[e]=y;w[e]=z; } int n,m; int mx,size,rt; int sz[maxn],son[maxn]; int vis[maxn],ans[20000005],top,st[maxn]; inline void getrt(int x,int fa){ sz[x]=1;son[x]=0; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(t==fa||vis[t])continue; getrt(t,x); sz[x]+=sz[t]; if(sz[t]>son[x])son[x]=sz[t]; } if(size-sz[x]>son[x])son[x]=size-sz[x]; if(son[x]<mx)mx=son[x],rt=x; } inline void query(int pos,int fa,int have){ st[++top]=have; for(int i=beg[pos];i;i=nex[i]){ int t=to[i]; if(t==fa||vis[t])continue; query(t,pos,have+w[i]); } } int q[maxn],p,ask[maxn],test[maxn]; inline void solve(int pos){ p=0;ans[0]=1; for(int i=beg[pos];i;i=nex[i]){ int t=to[i]; if(vis[t])continue; top=0; query(t,pos,w[i]); for(int j=top;j;j--) for(int k=1;k<=m;k++) if(ask[k]>=st[j])test[k]|=ans[ask[k]-st[j]]; for(int j=1;j<=top;j++) q[++p]=st[j],ans[st[j]]=1; } for(int i=1;i<=p;i++) ans[q[i]]=0; } inline void devide(int x){ vis[x]=1; solve(x); for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t])continue; mx=inf,size=sz[t],rt=0; getrt(t,0); devide(rt); } } int main(){ n=read(),m=read(); int x,y,z; for(int i=1;i<n;i++){ x=read(),y=read(),z=read(); add(x,y,z); add(y,x,z); } for(int i=1;i<=m;i++) ask[i]=read(); mx=inf,size=n,rt=0; getrt(1,0); devide(rt); for(int i=1;i<=m;i++) puts(test[i]?"AYE":"NAY"); return 0; }
请注意ans数组要尽量开大一点,不然会RE,QAQ。
例题3:P4178 Tree
网址:https://www.luogu.com.cn/problem/P4178
我个人觉得这道题比例1难比例2水,难点主要是在处理的部分。
对于每一条连到当前根的比k小的边edge,找比k-edge小的,不难想到树状数组。
比较坑的是树状数组一定要定义局部变量,不然会奇妙WA哈哈。
看一下代码:
#include<bits/stdc++.h> using namespace std; int n,k; const int maxn=100000; int beg[maxn],nex[maxn],to[maxn],w[maxn],e; void add(int x,int y,int z){ e++;nex[e]=beg[x]; beg[x]=e;to[e]=y;w[e]=z; } #define inf 1e9 int rt,mx,size; int sz[maxn],son[maxn],vis[maxn]; inline void getrt(int x,int fa){ sz[x]=1,son[x]=0; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t]||t==fa)continue; getrt(t,x); sz[x]+=sz[t]; if(son[x]<sz[t])son[x]=sz[t]; } if(son[x]<size-sz[x])son[x]=size-sz[x]; if(mx>son[x])mx=son[x],rt=x; } int q[maxn],ans,top; inline void stk(int x,int fa,int val){ if(val>k)return; if(val==k){ ans++; return; } q[++top]=val; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t]||t==fa)continue; stk(t,x,val+w[i]); } } inline int lowbit(int x){ return x&(-x); } inline void divide(int x){ //printf("%d ",x); vis[x]=1; int c[k+10]={0}; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t])continue; top=0; stk(t,x,w[i]); ans+=top; //printf("%d %d %d %d ",t,x,top,ans); for(int j=1;j<=top;j++){ int tmp=k-q[j]; while(tmp){ ans+=c[tmp]; tmp-=lowbit(tmp); } } //printf("%d ",ans); for(int j=1;j<=top;j++){ int tmp=q[j]; while(tmp<=k){ c[tmp]++; tmp+=lowbit(tmp); } } mx=inf,size=sz[t],rt=0; getrt(t,x); divide(rt); } //printf("%d %d ",qwq,x); } int main(){ cin>>n; int x,y,z; for(int i=1;i<n;i++){ scanf("%d%d%d",&x,&y,&z); add(x,y,z);add(y,x,z); } cin>>k; mx=inf,rt=0,size=n; getrt(1,0); divide(rt); printf("%d ",ans); return 0; }
努力想想的我一遍A了。
例题4:CF161D Distance in Tree
网址:https://www.luogu.com.cn/problem/CF161D
这可以说是最简单的一道了吧,用一个桶去记录就好了。
#include<bits/stdc++.h> using namespace std; const int maxn=100000+10; #define inf 1e9 int n,k; int beg[maxn],nex[maxn],to[maxn],w[maxn],e; void add(int x,int y,int z){ e++;nex[e]=beg[x]; beg[x]=e;to[e]=y;w[e]=z; } int mx,size,rt; int sz[maxn],son[maxn],vis[maxn]; inline void getrt(int x,int fa){ sz[x]=1,son[x]=0; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t]||t==fa)continue; getrt(t,x); sz[x]+=sz[t]; if(son[x]<sz[t])son[x]=sz[t]; } if(son[x]<size-sz[x])son[x]=size-sz[x]; if(son[x]<mx)mx=son[x],rt=x; } int top,q[maxn],ans; inline void stk(int x,int fa,int val){ if(val>k)return; if(val==k){ ans++; return; } q[++top]=val; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(t==fa||vis[t])continue; stk(t,x,val+w[i]); } } inline void divide(int x){ int bot[500+10]={0}; vis[x]=1; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t])continue; top=0;stk(t,x,w[i]); for(int i=1;i<=top;i++) ans+=bot[k-q[i]]; for(int i=1;i<=top;i++) bot[q[i]]++; mx=inf,size=sz[t],rt=0; getrt(t,x); divide(rt); } } int main(){ cin>>n>>k; int x,y,z; for(int i=1;i<n;i++){ scanf("%d%d",&x,&y); add(x,y,1);add(y,x,1); } mx=inf,size=n,rt=0; getrt(1,0); divide(rt); printf("%d ",ans); return 0; }
只要前面学懂了,这题应该要一遍AC,就像我一样^_^
例题5:P4149 [IOI2011]Race
转眼间,做了五道了。唉。
这一题k有1000000的数量级,所以我没敢开一百万的局部数组。
为什么全局有可能会出问题?因为在递归的时候,这一层的数据有可能被下一层利用。
避免的方法就是把处理和递归分开。完事!
看代码:
#include<bits/stdc++.h> using namespace std; const int maxn=1000000+100; #define inf 1e9 int n,k; int beg[maxn],nex[maxn],to[maxn],w[maxn],e; void add(int x,int y,int z){ e++;nex[e]=beg[x]; beg[x]=e;to[e]=y;w[e]=z; } int mx,rt,size; int sz[maxn],son[maxn],vis[maxn]; inline void getrt(int x,int fa){ sz[x]=1,son[x]=0; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(t==fa||vis[t])continue; getrt(t,x); sz[x]+=sz[t]; if(son[x]<sz[t])son[x]=sz[t]; } if(son[x]<size-sz[x])son[x]=size-sz[x]; if(son[x]<mx)mx=son[x],rt=x; } int bot[maxn],top,q1[maxn],q2[maxn],ans; inline void stk(int x,int fa,int val,int edge){ if(val>k)return; if(val==k){ ans=min(ans,edge); return; } if(edge>=ans)return; q1[++top]=val; q2[top]=edge; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(t==fa||vis[t])continue; stk(t,x,val+w[i],edge+1); } } int tmp[maxn],qwq; inline void divide(int x){ vis[x]=1;qwq=0; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t])continue; top=0; stk(t,x,w[i],1); for(int j=1;j<=top;j++) ans=min(ans,q2[j]+bot[k-q1[j]]); for(int j=1;j<=top;j++) bot[q1[j]]=min(bot[q1[j]],q2[j]); for(int j=1;j<=top;j++) tmp[++qwq]=q1[j]; } for(int i=1;i<=qwq;i++) bot[tmp[i]]=inf; for(int i=beg[x];i;i=nex[i]){ int t=to[i]; if(vis[t])continue; mx=inf,rt=0,size=sz[t]; getrt(t,x); divide(rt); } } int main(){ cin>>n>>k; int x,y,z; for(int i=1;i<n;i++){ scanf("%d%d%d",&x,&y,&z); x++,y++; add(x,y,z),add(y,x,z); } for(int i=1;i<=k;i++) bot[i]=1e9; ans=inf; mx=inf,rt=0,size=n; getrt(1,0); divide(rt); if(ans>=n)puts("-1"); else printf("%d ",ans); return 0; }
不巧,我又是一遍AC的,哈哈哈!