题意:给定字符串Str,求出回文串集合为S,问S中的(a,b)满足a是b的子串的对数。
思路:开始和题解的思路差不多,维护当前后缀的每个串的最后出现位置,但是不知道怎么套“最小回文分割”,所以想到了树剖,但是树剖不好同时维护“最后出现的次数”,“查询左端点>=L”的位置数。 所以GG。 那么从图论的角度考虑,有向图,问多少个点可以到达的关系点对,(我怎么只会bitset解决小数据问题)。
1,而回文树的特殊性在于,每个点只有一个fail(回边),25个next(出边),那么把回边抽离出来建立fail树。那么就是顺着next可以得到答案,问题是要去重。
去重的过程可以加一个vis数组标记即可。 这种有向图,DFS可做。
2,这里为了和比赛的时候想法接近一下,是用树剖实现的。 fail树上树剖,next指针搜索,每次把当前点的fail链标记量++; 回溯的时候标记量--;由于都是先加后删,所以没必要下推tag,这样反而保证了复杂度。
#include<bits/stdc++.h> #define rep(i,a,b) for(int i=a;i<=b;i++) using namespace std; const int maxn=200010; int tot; struct PAT { struct node{ int len,num,fail,son[26]; }t[maxn]; int last,n,s[maxn]; void init() { memset(t,0,sizeof(t)); tot=last=1; n=0; t[0].len=0; t[1].len=-1; t[0].fail=t[1].fail=1; s[0]=-1; } int add(int c){ int p=last; s[++n]=c; while(s[n]!=s[n-1-t[p].len]) p=t[p].fail; if(!t[p].son[c]){ int v=++tot,k=t[p].fail; while(s[n]!=s[n-t[k].len-1]) k=t[k].fail; t[v].fail=t[k].son[c]; t[v].len=t[p].len+2; t[v].num=t[t[v].fail].num+1; t[p].son[c]=v; } last=t[p].son[c]; return t[last].num; } }T; char c[maxn]; int Laxt[maxn],Next[maxn],To[maxn],cnt; int son[maxn],Top[maxn],fa[maxn],pos[maxn],times; int tag[maxn<<2],sum[maxn<<2],sz[maxn]; void Add(int u,int v) { Next[++cnt]=Laxt[u]; Laxt[u]=cnt; To[cnt]=v; } void dfs1(int u,int f) { sz[u]=1; son[u]=0; fa[u]=f; for(int i=Laxt[u];i;i=Next[i]){ if(To[i]!=f){ dfs1(To[i],u); sz[u]+=sz[To[i]]; if(sz[To[i]]>sz[son[u]]) son[u]=To[i]; } } } void dfs2(int u,int tp) { Top[u]=tp; pos[u]=++times; if(son[u]) dfs2(son[u],tp); for(int i=Laxt[u];i;i=Next[i]){ if(To[i]==son[u]) continue; dfs2(To[i],To[i]); } } void pushup(int Now,int L,int R) { if(tag[Now]) sum[Now]=R-L+1; else if(L==R) sum[Now]=0; else sum[Now]=sum[Now<<1]+sum[Now<<1|1]; } void add(int Now,int L,int R,int l,int r,int x) { if(l<=L&&r>=R) { tag[Now]+=x; pushup(Now,L,R); return ; } int Mid=(L+R)>>1; if(l<=Mid) add(Now<<1,L,Mid,l,r,x); if(r>Mid) add(Now<<1|1,Mid+1,R,l,r,x); pushup(Now,L,R); } int query(int Now,int L,int R,int l,int r) { if(l<=L&&r>=R){ if(tag[Now]) return R-L+1; if(L==R) return 0; return sum[Now<<1]+sum[Now<<1|1]; } int res=0,Mid=(L+R)>>1; if(l<=Mid) res+=query(Now<<1,L,Mid,l,r); if(r>Mid) res+=query(Now<<1|1,Mid+1,R,l,r); pushup(Now,L,R); return res; } void ADD(int p,int x) { while(Top[p]!=1) { add(1,1,tot,pos[Top[p]],pos[p],x); p=T.t[Top[p]].fail; } if(p!=1) add(1,1,tot,2,pos[p],x); } long long ans; void solve(int u) { if(u>1) ADD(u,1),ans+=sum[1]-1; rep(i,0,25) if(T.t[u].son[i]) solve(T.t[u].son[i]); if(u>1) ADD(u,-1); } int main() { int TT,C=0,N; scanf("%d",&TT); while(TT--){ T.init(); ans=0; times=0; scanf("%s",c+1); N=strlen(c+1); rep(i,1,N) T.add(c[i]-'a'); rep(i,1,tot) Laxt[i]=0; cnt=0; rep(i,2,tot) { if(T.t[i].fail==0) T.t[i].fail=1; Add(T.t[i].fail,i); } dfs1(1,0); dfs2(1,1); solve(1); solve(0); printf("Case #%d: %lld ",++C,ans); } return 0; }