题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=2586
在线版本:
在线方法的思路很简单,就是倍增。一遍dfs得到每个节点的父亲,以及每个点的深度。然后用dp得出每个节点向上跳2^k步到达的节点。
那么对于一个查询u,v,不妨设depth[u]>=depth[v],先让u向上跳depth[u]-depth[v]步,跳的方法就是直接用数字的二进制表示跳。
然后现在u和v都在同一深度上了,再二分找向上共同的祖先,就可以二分出lca了。复杂度nlogn预处理+qlogn查询。
#include<bits/stdc++.h> using namespace std; int read() { char c=getchar(); while (!isdigit(c)) c=getchar(); int x=0; while (isdigit(c)) { x=x*10+c-'0'; c=getchar(); } return x; } const int maxn=40005; const int maxm=maxn*2; int head[maxn]; struct Edge { int u,v,w,nxt; }edge[maxm]; int tot; void init() { tot=0; memset(head,-1,sizeof(head)); } void addedge(int u,int v,int w) { ++tot; edge[tot].u=u; edge[tot].v=v; edge[tot].w=w; edge[tot].nxt=head[u]; head[u]=tot; } int dep[maxn]; int pa[maxn]; int dis[maxn]; void dfs(int u,int f,int de,int d) { pa[u]=f; dep[u]=de; dis[u]=d; for (int i=head[u];i!=-1;i=edge[i].nxt) { int v=edge[i].v; int w=edge[i].w; if (v!=f) { dfs(v,u,de+1,d+w); } } } int fa[maxn][20]; int getlca(int u,int v) { if (dep[u]<dep[v]) swap(u,v); int jump=dep[u]-dep[v]; for (int i=0;i<20;i++) { if (jump&1) u=fa[u][i]; jump>>=1; } for (int i=19;i>=0;i--) { if (fa[u][i]!=fa[v][i]) { u=fa[u][i]; v=fa[v][i]; } } if (u!=v) return pa[u]; else return u; } int main() { int t; t=read(); while (t--) { init(); int n,q; n=read(); q=read(); for (int i=0;i<n-1;i++) { int u,v,w; u=read(); v=read(); w=read(); addedge(u,v,w); addedge(v,u,w); } dfs(1,0,0,0); for (int i=1;i<=n;i++) fa[i][0]=pa[i]; for (int i=1;i<20;i++) for (int j=1;j<=n;j++) fa[j][i]=fa[fa[j][i-1]][i-1]; while (q--) { int u,v; u=read(); v=read(); int lca=getlca(u,v); int ans=dis[u]+dis[v]-2*dis[lca]; printf("%d ",ans); } } return 0; }
离线版本:
离线版本复杂度低,所以有时候还是很有必要会的。(有一次比赛在线的做法就被卡常了)
离线版本的思路是:每个点都把与它有关的查询放进它的那个vector里,然后对这棵树进行一次dfs,在dfs的过程中直接得出所有查询的答案,复杂度是O(n+q)。
具体来说,假设有一个查询u,v,当遍历到u的时候,如果v还没有遍历,就先不管这个查询;如果v已经遍历过了,那就处理这个查询,那么这个查询的结果是什么呢?结果就是v向上一直找,找到深度最浅的那个已经遍历过并且当前正在考虑这棵子树的节点,就是u和v的lca。这是用到了dfs的中序遍历性质,很难表达清楚,可以通过想象感觉一下。那么怎么得到v向上一直找,找到最浅的那个已经遍历过的节点呢?并查集。具体可以参照代码,这个感觉真的只能通过想象感觉出来,确实很难表述。
#include<bits/stdc++.h> using namespace std; int read() { char c=getchar(); while (!isdigit(c)) c=getchar(); int x=0; while (isdigit(c)) { x=x*10+c-'0'; c=getchar(); } return x; } const int maxn=40005; const int maxm=maxn*2; int head[maxn]; struct Edge { int u,v,w,nxt; }edge[maxm]; int tot; void init() { tot=0; memset(head,-1,sizeof(head)); } void addedge(int u,int v,int w) { ++tot; edge[tot].u=u; edge[tot].v=v; edge[tot].w=w; edge[tot].nxt=head[u]; head[u]=tot; } vector< pair<int,int> > Q[maxn]; bool vis[maxn]; int fa[maxn]; int ans[205]; int dis[maxn]; void addquery(int u,int v,int id) { Q[u].push_back(make_pair(v,id)); Q[v].push_back(make_pair(u,id)); } int findfa(int x) { if (fa[x]==x) return x; return fa[x]=findfa(fa[x]); } int n,q; void dfs(int u,int d) { dis[u]=d; vis[u]=true; for (int i=head[u];i!=-1;i=edge[i].nxt) { int v=edge[i].v; int w=edge[i].w; if (!vis[v]) { dfs(v,d+w); fa[v]=u; } } for (int i=0;i<Q[u].size();i++) { int v=Q[u][i].first; int id=Q[u][i].second; if (vis[v]) ans[id]=dis[u]+dis[v]-2*dis[findfa(v)]; } } int main() { int t; t=read(); while (t--) { init(); n=read(); q=read(); for (int i=0;i<n-1;i++) { int u,v,w; u=read(); v=read(); w=read(); addedge(u,v,w); addedge(v,u,w); } for (int i=0;i<maxn;i++) Q[i].clear(); memset(vis,false,sizeof(vis)); for (int i=1;i<=n;i++) fa[i]=i; for (int i=0;i<q;i++) { int u,v; u=read(); v=read(); addquery(u,v,i); } dfs(1,0); for (int i=0;i<q;i++) printf("%d ",ans[i]); } return 0; }