思路:
1. 并查集 + 倍增求 LCA（下面的代码实现的就是这种做法；代码中并没有用到线段树合并）
注意：两段路径都标记完之后，若 f[LCA]==LCA（说明 LCA 此前还没有被经过），要令 f[LCA]=fa[LCA]，这样 LCA 本身也会被视为已经过。
2.LCT(并不会写啊...)
// By SiriusRen
//
// A train starts at node `a` of a tree with n nodes and is asked to visit the
// nodes num[1..m] in order. A requested node that already lies on a path the
// train has travelled is skipped; otherwise the train travels the tree path
// from its current position to that node, and the path length (in edges) is
// added to the answer, which is printed at the end.
//
// Travelled nodes are tracked with a union-find whose links always point to
// an ancestor: find(x) == x iff x has not been travelled over yet, and for a
// travelled x, find(x) is the nearest untravelled proper ancestor (or 0 above
// the root). Path marking uses find() to jump over already-travelled
// stretches, so every untravelled node is marked exactly once.
#include <algorithm>
#include <cstdio>
#include <cstring>

const int N = 500050;
const int LOG = 20;  // 1 << 19 = 524288 > N, so 20 lifting levels suffice

// Adjacency list stored as linked edge lists (first[], nxt[], v[]).
// nxt[] used to be called `next`; together with `using namespace std;` that
// name was ambiguous with std::next and failed to compile.
int first[N], nxt[N * 2], v[N * 2], tot;
int deep[N];       // depth of each node; the root (node 1) has depth 1
int fa[N][LOG];    // fa[x][k] = 2^k-th ancestor of x, 0 above the root
int f[N];          // union-find links used to skip travelled nodes
int num[N];        // requested nodes, 1-indexed
int bfs_queue[N];  // work queue for the iterative traversal

// Appends the directed edge x -> y to x's edge list.
void add(int x, int y) { v[tot] = y, nxt[tot] = first[x], first[x] = tot++; }

// Fills deep[] and fa[][0] by breadth-first search from `root`.
// Iterative on purpose: a recursive DFS on a path-shaped tree with ~5e5 nodes
// can overflow the call stack.
void bfs(int root) {
    int head = 0, tail = 0;
    deep[root] = 1;
    fa[root][0] = 0;
    bfs_queue[tail++] = root;
    while (head < tail) {
        const int x = bfs_queue[head++];
        for (int i = first[x]; ~i; i = nxt[i]) {
            const int y = v[i];
            if (y == fa[x][0]) continue;  // do not walk back to the parent
            deep[y] = deep[x] + 1;
            fa[y][0] = x;
            bfs_queue[tail++] = y;
        }
    }
}

// Lowest common ancestor of x and y using the binary-lifting table fa[][].
int lca(int x, int y) {
    if (deep[x] < deep[y]) std::swap(x, y);
    for (int k = LOG - 1; k >= 0; k--)
        if (deep[x] - (1 << k) >= deep[y]) x = fa[x][k];
    if (x == y) return x;
    for (int k = LOG - 1; k >= 0; k--)
        if (fa[x][k] != fa[y][k]) x = fa[x][k], y = fa[y][k];
    return fa[x][0];
}

// Union-find lookup with full path compression. Iterative because link
// chains can be O(n) long (e.g. when a path is walked upward one LCA at a
// time), which would overflow the stack in a recursive version.
int find(int x) {
    int root = x;
    while (f[root] != root) root = f[root];
    while (f[x] != root) {
        const int parent = f[x];
        f[x] = root;
        x = parent;
    }
    return root;
}

// Marks the path from `from` up to, but excluding, its ancestor `top` as
// travelled by linking each node on it to `top`. Already-travelled stretches
// are skipped via find(); `from` itself is always (re)linked.
void mark_up(int from, int top) {
    for (int j = from; deep[j] > deep[top]; j = find(fa[j][0])) f[j] = top;
}

int main() {
    memset(first, -1, sizeof(first));
    int n, m, a;
    scanf("%d%d%d", &n, &m, &a);
    for (int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y), add(y, x);
    }
    bfs(1);
    for (int k = 1; k < LOG; k++)
        for (int i = 1; i <= n; i++) fa[i][k] = fa[fa[i][k - 1]][k - 1];
    for (int i = 1; i <= m; i++) scanf("%d", &num[i]);
    for (int i = 1; i <= n; i++) f[i] = i;

    long long ans = 0;
    for (int i = 1; i <= m; i++) {
        const int target = num[i];
        if (find(target) != target) continue;  // already travelled over: skip
        const int top = lca(a, target);
        ans += deep[a] + deep[target] - 2 * deep[top];
        mark_up(a, top);
        mark_up(target, top);
        // `top` itself lies on the travelled path; if it was still unvisited,
        // link it to its parent so later find() calls treat it as visited.
        if (f[top] == top) f[top] = fa[top][0];
        a = target;
    }
    printf("%lld ", ans);
    return 0;
}