一定注意每一次都要是 $root[cur]=root[cur-1]$,不然进行合并时如果 $a,b$ 在同一集合中就会使 $root[cur]=0$.
Code:
#include <cstdio> #include <algorithm> #include <cstring> #include <string> using namespace std; //ty==1 for size, ty==0 for father void setIO(string a){ freopen((a+".in").c_str(),"r",stdin); } #define maxn 200005 int n,m,cur,root[maxn*4]; struct Node{ int f,siz; Node(int f=0,int siz=0):f(f),siz(siz){} }; struct Segment_Tree{ int lson[maxn*50],rson[maxn*50],fa[maxn*50],siz[maxn*50]; int nodes; void build(int l,int r,int &o){ if(l>r)return; o=++nodes; if(l==r) { siz[o]=1,fa[o]=l; return; } int mid=(l+r)>>1; build(l,mid,lson[o]), build(mid+1,r,rson[o]); } int update(int l,int r,int o,int pos,int ty,int k){ int oo=++nodes; lson[oo]=lson[o],rson[oo]=rson[o],fa[oo]=fa[o],siz[oo]=siz[o]; if(l==r) { if(ty==1) siz[oo]=k; if(ty==0) fa[oo]=k; return oo; } int mid=(l+r)>>1; if(pos<=mid) lson[oo]=update(l,mid,lson[o],pos,ty,k); else rson[oo]=update(mid+1,r,rson[o],pos,ty,k); return oo; } Node query(int l,int r,int o,int pos){ if(l==r){ return Node(fa[o],siz[o]); } int mid=(l+r)>>1; if(pos<=mid) return query(l,mid,lson[o],pos); else return query(mid+1,r,rson[o],pos); } Node find(int x,int state){ Node p=query(1,n,root[state],x); return p.f==x?p:find(p.f,state); } void merge(int a,int b,int state){ Node x=find(a,state), y=find(b,state); if(x.f==y.f) return; if(x.siz>y.siz) root[cur]=update(1,n,root[state],x.f,1,y.siz+x.siz),root[cur]=update(1,n,root[cur],y.f,0,x.f); else root[cur]=update(1,n,root[state],y.f,1,y.siz+x.siz),root[cur]=update(1,n,root[cur],x.f,0,y.f); } int ask(int a,int b,int state){ Node x=find(a,state),y=find(b,state); if(x.f==y.f)return 1; return 0; } }S; int main(){ // setIO("input"); int opt,a,b,lastans=0; scanf("%d%d",&n,&m); S.build(1,n,root[0]); for(cur=1;cur<=m;++cur){ scanf("%d",&opt); root[cur]=root[cur-1]; switch(opt) { case 1: { scanf("%d%d",&a,&b),a^=lastans,b^=lastans,S.merge(a,b,cur-1); break;} case 2: { scanf("%d",&a),a^=lastans,root[cur]=root[a]; break;} case 3: { scanf("%d%d",&a,&b),a^=lastans,b^=lastans,lastans=S.ask(a,b,cur-1),printf("%d ",lastans); break;} } } return 0; }