一、基本思想
点分治是用来解决树上路径问题的一种方法。简单地写一下 QwQ。
首先,给这棵树钦定一个根(不妨设为 (x)),再将这棵树上的所有简单路径分为两个部分:
- 第一部分:经过 (x) 的简单路径(设路径的两端为 ((u,v)),下同。(u,v) 在根 (x) 的不同子树内)。
- 第二部分:不经过 (x) 的简单路径((u,v) 都在根 (x) 的某棵子树内)。
根据 分治 的思想,发现对于第二部分的路径可以作为一个子问题递归到子树内的点计算,于是我们对于 (x) 只考虑第一部分的路径即可。
复杂度的瓶颈在递归次数上,每个分治子树的根节点不能随便选择。
若我们每次选择子树的 重心 作为根节点,则最大的子树大小不会超过原树的一半(每次递归子树大小至少缩小一半),可以保证递归层数最少,时间复杂度为 (mathcal{O}(nlog n))。
二、代码实现
给定一棵有 (n) 个点的树,(m) 次询问,每次询问树上距离为 (k) 的点对是否存在。
(1leq nleq 10^4,1leq mleq 100,1leq kleq 10^7)。
找重心:这里 提到过。重心就是 最大子树大小 最小 的节点。
若一个点已经被当做根过,那我们就不去管它。最后每个节点都会被作为重心递归一次(请自行思考),所以不会漏算。具体见代码。
//Luogu P3806 #include<bits/stdc++.h> #define int long long using namespace std; const int N=1e4+5,K=1e8+5; int n,m,x,y,z,q[N],cnt,hd[N],to[N<<1],nxt[N<<1],val[N<<1],sz[N],g[N],tot,root,dis[N],top,s[N],t[N]; bool vis[N],f[K],ans[N]; void add(int x,int y,int z){ to[++cnt]=y,nxt[cnt]=hd[x],hd[x]=cnt,val[cnt]=z; } void getRoot(int x,int fa){ //找根 sz[x]=1,g[x]=0; //g[x]: 以 x 为根时的最大子树大小 for(int i=hd[x];i;i=nxt[i]){ int y=to[i]; if(y==fa||vis[y]) continue; //vis[x]: x 是否被当做根过 getRoot(y,x),sz[x]+=sz[y],g[x]=max(g[x],sz[y]); } g[x]=max(g[x],tot-sz[x]); //tot: 当前递归的这棵子树的大小 if(g[x]<g[root]) root=x; //root: 当前找到的根 } void getDis(int x,int fa,int d){ s[++top]=dis[x]=d; for(int i=hd[x];i;i=nxt[i]){ int y=to[i],z=val[i]; if(y==fa||vis[y]) continue; getDis(y,x,d+z); } } void calc(int x){ //对于 x,考虑经过 x 的路径 int num=0; for(int i=hd[x];i;i=nxt[i]){ int y=to[i],z=val[i]; if(vis[y]) continue; top=0,getDis(y,x,z); //计算以 y 为根的子树中的点与 x 之间的距离 for(int j=1;j<=top;j++) for(int k=1;k<=m;k++) if(q[k]>=s[j]) ans[k]|=f[q[k]-s[j]]; //对于第 k 次询问 q[k],s[j] 是以 y 为根的子树中的某个点与 x 之间的距离,然后看之前是否有和 x 距离为 q[k]-s[j] 的点(两点在 x 的两个不同子树中) for(int j=1;j<=top;j++) f[t[++num]=s[j]]=1; //标记。记录 t 数组方便清空(这样就可以不用 memset,以保证复杂度) } for(int i=1;i<=num;i++) f[t[i]]=0; } void solve(int x){ vis[x]=f[0]=1,calc(x); for(int i=hd[x];i;i=nxt[i]){ int y=to[i]; if(vis[y]) continue; tot=sz[y],root=0,getRoot(y,0),solve(root); //递归 } } signed main(){ scanf("%lld%lld",&n,&m); for(int i=1;i<n;i++){ scanf("%lld%lld%lld",&x,&y,&z); add(x,y,z),add(y,x,z); } for(int i=1;i<=m;i++) scanf("%lld",&q[i]); g[root=0]=1e18,tot=n,getRoot(1,0),solve(root); for(int i=1;i<=m;i++) puts(ans[i]?"AYE":"NAY"); return 0; }
后面可能还要写点什么,先占个坑 QAQ。