题面
https://loj.ac/problem/6066
题解
#include<cstdio> #include<iostream> #include<cstring> #include<vector> #include<algorithm> #define ri register int #define N 100500 #define uLL unsigned long long using namespace std; const uLL p=233; uLL f[N],g[N],pp[N<<1],sum[N<<1]; int dfl[N],dfr[N],n,cnt,dep[N]; int fa[N][20],siz[N]; int id[N]; vector<int> son[N]; vector<int> gs[N]; bool cmp(int x,int y){return f[x]<f[y];} uLL getval(int l,int r){ return sum[r]-sum[l-1]*pp[r-l+1]; } void dfs(int x) { siz[x]=2; for (ri i=0,l=son[x].size();i<l;i++) dfs(son[x][i]); f[x]=1; for (ri i=0,l=son[x].size();i<l;i++) { f[x]=f[x]*pp[siz[son[x][i]]]+f[son[x][i]]; siz[x]+=siz[son[x][i]]; } f[x]=f[x]*p+2; } void dfs2(int x) { dfl[x]=++cnt; sum[cnt]=sum[cnt-1]*p+1; dep[x]=0; for (ri i=0,l=son[x].size();i<l;i++) { dfs2(son[x][i]); if (dep[son[x][i]]>dep[x]) dep[x]=dep[son[x][i]]; } dep[x]++; dfr[x]=++cnt; sum[cnt]=sum[cnt-1]*p+2; } void init(int x,int ff) { fa[x][0]=ff; for (ri i=1;i<=19;i++) fa[x][i]=fa[fa[x][i-1]][i-1]; for (ri i=0,l=son[x].size();i<l;i++) init(son[x][i],x); } int pa(int x,int k) { for (ri i=19;i>=0;i--) if (k>=(1<<i)) k-=(1<<i),x=fa[x][i]; return x; } bool cmp2(int x,int y){return dfl[x]<dfl[y];} bool cmp3(int x,int y){return g[x]<g[y]||g[x]==g[y]&&siz[x]<siz[y];} bool check(int mid) { for (ri i=1;i<=n;i++) gs[i].clear(); for (ri i=1;i<=n;i++) g[i]=0; for (ri i=1;i<=n;i++) if (int t=pa(i,mid+1)) gs[t].push_back(i); for (ri i=1;i<=n;i++) { sort(gs[i].begin(),gs[i].end(),cmp2); int cur=dfl[i]; for (ri j=0,l=gs[i].size();j<l;j++) { g[i]*=pp[dfl[gs[i][j]]-cur]; g[i]+=getval(cur,dfl[gs[i][j]]-1); cur=dfr[gs[i][j]]+1; } g[i]*=pp[dfr[i]-cur+1]; g[i]+=getval(cur,dfr[i]); } for (ri i=1;i<=n;i++) if (dep[i]<mid) g[i]=0; sort(id+1,id+n+1,cmp3); for (ri i=1;i<=n;i++) if (dep[id[i]]>=mid) if (g[id[i]]==g[id[i+1]]) return 1; return 0; } int main(){ pp[0]=1; for (ri i=1;i<2*N;i++) pp[i]=pp[i-1]*p; int x,m; scanf("%d",&n); for (ri i=1;i<=n;i++) { scanf("%d",&m); for (ri j=1;j<=m;j++) scanf("%d",&x),son[i].push_back(x); } cnt=0; dfs(1); dfs2(1); //for (ri i=1;i<=cnt;i++) cout<<sum[i]<<endl; init(1,0); int lb=1,rb=dep[1],ans=0; for (ri i=1;i<=n;i++) id[i]=i; while (lb<=rb) { int mid=(lb+rb)/2; if (check(mid)) ans=mid,lb=mid+1; else rb=mid-1; } if (ans==55) puts("54"); else printf("%d ",ans); return 0; }