题意
给出一个字符串s1和q个询问,每个询问给出一个字符串s2,问这个询问的字符串的所有不同的周期串在s1中出现的次数的和。
分析
对于s1建后缀自动机。对于询问的每个字符串s2,我们按照处理循环串的方法,将它长度乘二再复制一遍。然后根据s2在自动机上跑,当长度len=n的时候,就更新答案。因为要求统计的是不同的周期串,所以对于每个状态都需要打一个vis标记。
![](https://images.cnblogs.com/OutliningIndicators/ContractedBlock.gif)
1 #include <cstdio> 2 #include <cstring> 3 #include <iostream> 4 #include <algorithm> 5 6 using namespace std; 7 const int maxn=2e6+10; 8 char s[maxn]; 9 struct state{ 10 int len,link; 11 int next[26]; 12 }st[2*maxn]; 13 int last,cur,sz,Q,n; 14 int cnt[2*maxn],c[2*maxn],vis[2*maxn]; 15 void init(){ 16 sz=1; 17 last=cur=0; 18 st[0].link=-1; 19 st[0].len=0; 20 } 21 void build_sam(int c){ 22 cur=sz++; 23 st[cur].len=st[last].len+1; 24 cnt[cur]=1; 25 int p; 26 for(p=last;p!=-1&&st[p].next[c]==0;p=st[p].link){ 27 st[p].next[c]=cur; 28 } 29 if(p==-1) 30 st[cur].link=0; 31 else{ 32 int q=st[p].next[c]; 33 if(st[q].len==st[p].len+1) 34 st[cur].link=q; 35 else{ 36 int clone=sz++; 37 st[clone].len=st[p].len+1; 38 st[clone].link=st[q].link; 39 for(int i=0;i<26;i++) 40 st[clone].next[i]=st[q].next[i]; 41 for(;p!=-1&&st[p].next[c]==q;p=st[p].link){ 42 st[p].next[c]=clone; 43 } 44 st[cur].link=st[q].link=clone; 45 } 46 } 47 last=cur; 48 } 49 int cmp(int a,int b){ 50 return st[a].len>st[b].len; 51 } 52 int solve(int id){ 53 int res=0; 54 int u=0,len=0; 55 for(int i=0;i<2*n-1;i++){ 56 int c=s[i]-'a'; 57 while(u!=-1&&(st[u].next[c]==0)) 58 u=st[u].link,len=st[u].len; 59 if(u==-1) 60 u=0,len=0; 61 else{ 62 u=st[u].next[c]; 63 len++; 64 if(len>=n&&vis[u]!=id){ 65 res+=cnt[u]; 66 vis[u]=id; 67 } 68 while(n!=1&&st[u].link!=-1&&st[st[u].link].len>=n-1) 69 u=st[u].link,len=st[u].len; 70 } 71 } 72 return res; 73 } 74 75 int main(){ 76 scanf("%s",s); 77 n=strlen(s); 78 init(); 79 for(int i=0;i<n;i++) 80 build_sam(s[i]-'a'); 81 for(int i=0;i<sz;i++) 82 c[i]=i; 83 sort(c,c+sz,cmp); 84 for(int i=0;i<sz;i++){ 85 int o=c[i]; 86 if(st[o].link!=-1){ 87 cnt[st[o].link]+=cnt[o]; 88 } 89 } 90 91 scanf("%d",&Q); 92 for(int i=1;i<=Q;i++){ 93 // memset(vis,0,sizeof(vis)); 94 scanf("%s",s); 95 n=strlen(s); 96 for(int j=0;j<n;j++) 97 s[j+n]=s[j]; 98 int res=solve(i); 99 printf("%d ",res); 100 } 101 return 0; 102 }