题解:
很裸的LCA模板题,用map建树(用了两个小时改bug,终于过了,hhhh)
#include<iostream> #include<cstring> #include<algorithm> #include<vector> #include<map> using namespace std; const int maxn=1e5+5; int lg[maxn];//log 2n向下取整 const int maxbit=20; vector<int>G[maxn];int cnt;int n; int depth[maxn];//记录每个节点的深度 int father[maxn][maxbit];//father[i][j]指的是i节点往上2^j的节点 int T; int ind[maxn]; map<string,int>mm; void init() { for(int i=0;i<=n+5;i++) G[i].clear();T=0;cnt=0; memset(depth,0,sizeof depth); memset(father,0,sizeof father); memset(ind,0,sizeof ind); mm.clear(); } void dfs(int nowp,int fa) { depth[nowp]=depth[fa]+1;//当前节点深度是父节点深度+1 father[nowp][0]=fa; for(int j=1;j<=lg[depth[nowp]];j++)//倍增求father(递推) father[nowp][j]=father[father[nowp][j-1]][j-1]; for(int i=0;i<G[nowp].size();i++) { if(G[nowp][i]!=fa) dfs(G[nowp][i],nowp); } } int lca(int u,int v) { if(depth[u]<depth[v])//维护u的深度最大 swap(u,v); while(depth[u]!=depth[v]) u=father[u][lg[depth[u]-depth[v]]];//使两个节点在同一高度; if(u==v) return u; for(int j=lg[depth[u]];j>=0;--j) { if(father[u][j]!=father[v][j]) { u=father[u][j]; v=father[v][j]; } } //得到最近公共祖先下一位; return father[u][0]; } int capture(char s[]) { if(mm.find(s) == mm.end()) { return mm[s] = ++T; } else return mm[s]; } void addedge(int x,int y)//把边保存起来的函数 { G[x].push_back(y); G[y].push_back(x); } int main() { lg[0]=-1; for(int i=1;i<maxn;i++) lg[i]=lg[i>>1]+1; //dfs(s,0);//假设根结点的父节点为0; int k;scanf("%d",&k); while(k--) { int m;int x,y;char a[50],b[50]; scanf("%d%d",&n,&m);init();int s,t; for(int i=1;i<n;i++) { /*scanf("%s",c1); scanf("%s",c2); t = capture(c1); s = capture(c2); printf("%d %d ",t,s); G[t].push_back(s); G[s].push_back(t); ind[t]++;*/ scanf("%s%s",a,b); if(!mm[a]) mm[a]=++cnt; if(!mm[b]) mm[b]=++cnt; addedge(mm[a],mm[b]); ind[mm[a]]++; } for(int i=1;i<=n;i++) { if(ind[i]==0) dfs(i,0); } while(m--) { scanf("%s%s",a,b); if(mm[a]==mm[b]) {printf("0 ");continue;} int y=lca(mm[a],mm[b]); //printf("%d %d %d ",mm[a],y,mm[b]); if(y==mm[a]){ printf("1 ");continue;} else { // printf("%d %d %d ",depth[y],depth[mm[a]],depth[mm[b]]); int p=-depth[y]+depth[mm[a]]; if(mm[b]!=y)p++; printf("%d ",p); } } } return 0; }