首先把原树建出来,然后线段树合并,dfs序,查询子树大小
线段树合并真是一个神奇的东西
#include<cstdio> #include<algorithm> using namespace std; int ti,cnt,dfn[100005],Fa[100005],Dep[100005],root[100005],ls[2000005],rs[2000005],tree[2000005],F[100005],sz[100005],cas[100005],x[100005],y[100005],last[100005],ed[100005]; char s[5]; struct node{ int to,next; }e[200005]; void add(int a,int b){ e[++cnt].to=b; e[cnt].next=last[a]; last[a]=cnt; } void dfs(int x,int fa){ dfn[x]=++ti; Fa[x]=fa; for (int i=last[x]; i; i=e[i].next){ int V=e[i].to; if (V==fa) continue; Dep[V]=Dep[x]+1; dfs(V,x); } ed[x]=ti; } int query(int t,int l,int r,int x,int y){ if (!t) return 0; if (r<x || l>y) return 0; if (l>=x && r<=y) return tree[t]; int mid=(l+r)>>1; return query(ls[t],l,mid,x,y)+query(rs[t],mid+1,r,x,y); } void insert(int &t,int l,int r,int x){ if (!t) t=++cnt; tree[t]++; if (l==r) return; int mid=(l+r)>>1; if (x<=mid) insert(ls[t],l,mid,x); else insert(rs[t],mid+1,r,x); } int merge(int x,int y){ if (!x || !y) return x^y; tree[x]+=tree[y]; ls[x]=merge(ls[x],ls[y]); rs[x]=merge(rs[x],rs[y]); return x; } int find(int x){ if (F[x]!=x) F[x]=find(F[x]); return F[x]; } int main(){ int n,m; scanf("%d%d",&n,&m); for (int i=1; i<=m; i++){ scanf("%s%d%d",s,&x[i],&y[i]); if (s[0]=='A') cas[i]=0; else cas[i]=1; } for (int i=1; i<=m; i++) if (!cas[i]){ add(x[i],y[i]); add(y[i],x[i]); } for (int i=1; i<=n; i++) F[i]=i,sz[i]=1; for (int i=1; i<=n; i++) if (!dfn[i]) dfs(i,0); for (int i=1; i<=n; i++) insert(root[i],1,n,dfn[i]); for (int i=1; i<=m; i++) if (cas[i]==0){ int fx=find(x[i]),fy=find(y[i]); F[fy]=fx; sz[fx]+=sz[fy]; root[fx]=merge(root[fx],root[fy]); } else if (cas[i]==1){ int X=x[i],Y=y[i]; if (Dep[X]>Dep[Y]) swap(X,Y); int ans=sz[find(X)]; int ANS=query(root[find(X)],1,n,dfn[Y],ed[Y]); printf("%lld ",1ll*ANS*(ans-ANS)); } return 0; }