并查集 (union-find / disjoint-set union) solution:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;

// Union-find (disjoint-set) solution.
//
// Input: "n m", followed by m events:
//   1 a b : b becomes the direct boss of a (a is expected to have no boss yet).
//   2 a   : a new packet starts at a and travels up the boss chain to the
//           current top of a's chain; packets are numbered 1, 2, ... in order.
//   3 a k : print "YES" if a lies on the route packet k travelled, else "NO".
namespace {

// One recorded packet: where it started and the top of the chain at that time.
struct Packet {
    int start;
    int top;
};

class Hierarchy {
public:
    // Employees are indexed 0..n; everyone starts with no boss.
    explicit Hierarchy(int n) : dsu_(n + 1), boss_(n + 1) {
        for (int i = 0; i <= n; ++i) {
            dsu_[i] = i;
            boss_[i] = i;
        }
    }

    // True when x is a valid employee index for this hierarchy.
    bool contains(int x) const {
        return x >= 0 && x < static_cast<int>(boss_.size());
    }

    // Makes b the direct boss of a. As in the original solution, this assumes
    // a has no boss yet (so a is still a union-find root) and that the new
    // link does not create a cycle.
    void add_boss(int a, int b) {
        boss_[a] = b;
        dsu_[a] = b;
    }

    // Records a packet starting at a; its route ends at a's current top.
    void send_packet(int a) { packets_.push_back({a, root_of(a)}); }

    // True when k is the 1-based index of an already-recorded packet.
    bool has_packet(int k) const {
        return k >= 1 && k <= static_cast<int>(packets_.size());
    }

    // Did employee x sign packet k (1-based; requires has_packet(k))?
    // Walks the real boss links from the packet's start up to its recorded
    // top. A boss is never reassigned, so this is the chain the packet used.
    // NOTE(review): costs O(chain length) per query; an offline Euler-tour
    // ancestor test would answer in O(1) if long chains become a bottleneck.
    bool signed_packet(int x, int k) const {
        const Packet& p = packets_[k - 1];
        for (int cur = p.start;; cur = boss_[cur]) {
            if (cur == x) return true;
            if (cur == p.top) return false;
        }
    }

private:
    // Top of x's chain, with iterative path compression. Unlike a recursive
    // find, this cannot overflow the call stack on a very long chain.
    int root_of(int x) {
        int root = x;
        while (dsu_[root] != root) root = dsu_[root];
        while (dsu_[x] != root) {
            const int next = dsu_[x];
            dsu_[x] = root;
            x = next;
        }
        return root;
    }

    std::vector<int> dsu_;         // union-find parent; only used to locate chain tops
    std::vector<int> boss_;        // real direct boss (x itself when x has no boss)
    std::vector<Packet> packets_;  // packets in the order they were sent
};

}  // namespace

int main() {
    int n = 0;
    int m = 0;
    if (std::scanf("%d%d", &n, &m) != 2 || n < 0) return 0;

    Hierarchy org(n);
    while (m-- > 0) {
        int type = 0;
        if (std::scanf("%d", &type) != 1) break;  // truncated input: stop cleanly
        if (type == 1) {
            int a = 0;
            int b = 0;
            if (std::scanf("%d%d", &a, &b) != 2) break;
            if (!org.contains(a) || !org.contains(b)) {
                std::fprintf(stderr, "invalid employee index in event type 1\n");
                return 1;
            }
            org.add_boss(a, b);
        } else if (type == 2) {
            int a = 0;
            if (std::scanf("%d", &a) != 1) break;
            if (!org.contains(a)) {
                std::fprintf(stderr, "invalid employee index in event type 2\n");
                return 1;
            }
            org.send_packet(a);
        } else if (type == 3) {
            int a = 0;
            int k = 0;
            if (std::scanf("%d%d", &a, &k) != 2) break;
            if (!org.has_packet(k)) {
                std::fprintf(stderr, "invalid packet index in event type 3\n");
                return 1;
            }
            std::puts(org.signed_packet(a, k) ? "YES" : "NO");
        }
    }
    return 0;
}