题意
给定两个字符串,求两个字符串相同子串的方案数。
分析
那么将字符串s1建SAM,然后对于s2的每个前缀,都在SAM中找出来,并且计数就行。
我一开始的做法是,建一个u和len,顺着s2跑SAM,当st[u].next[c]存在的时候,u=st[u].next[c],len++,这时候找到了这个前缀的最长公共后缀,然后顺着parent边向上走,然后res+=cnt[u]*(len-st[st[u].link].len)。为什么是len-st[st[u].link].len。因为对于状态u,它的有效长度是[st[st[u].link].len+1,st[u].len]。但是这样写完以后TLE了。然后我就去看了下大佬们的做法。思路也是一样的只是记录一个f数组。
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 #include <iostream> 5 6 using namespace std; 7 const int maxn=200000+100; 8 typedef long long LL; 9 struct state{ 10 int len,link; 11 int next[26]; 12 }st[2*maxn]; 13 int cnt[2*maxn],c[2*maxn],ap[2*maxn]; 14 LL f[2*maxn]; 15 char s1[maxn],s2[maxn]; 16 int n1,n2; 17 int last,cur,sz; 18 void init(){ 19 sz=1; 20 last=cur=0; 21 st[0].link=-1; 22 st[0].len=0; 23 } 24 25 void build_sam(int c){ 26 cur=sz++; 27 cnt[cur]=1; 28 st[cur].len=st[last].len+1; 29 int p; 30 for(p=last;p!=-1&&st[p].next[c]==0;p=st[p].link) 31 st[p].next[c]=cur; 32 if(p==-1) 33 st[cur].link=0; 34 else{ 35 int q=st[p].next[c]; 36 if(st[q].len==st[p].len+1) 37 st[cur].link=q; 38 else{ 39 int clone=sz++; 40 st[clone].len=st[p].len+1; 41 st[clone].link=st[q].link; 42 for(int i=0;i<26;i++) 43 st[clone].next[i]=st[q].next[i]; 44 for(;p!=-1&&st[p].next[c]==q;p=st[p].link) 45 st[p].next[c]=clone; 46 st[cur].link=st[q].link=clone; 47 } 48 } 49 last=cur; 50 } 51 int cmp(int a,int b){ 52 return st[a].len>st[b].len; 53 } 54 55 LL update(int u,int len){ 56 LL res=0; 57 while(u){ 58 res+=(LL)(len-st[st[u].link].len)*cnt[u]; 59 u=st[u].link,len=st[u].len; 60 } 61 return res; 62 } 63 64 int main(){ 65 scanf("%s%s",s1,s2); 66 n1=strlen(s1),n2=strlen(s2); 67 init(); 68 for(int i=0;i<n1;i++){ 69 build_sam(s1[i]-'a'); 70 } 71 for(int i=0;i<sz;i++) 72 c[i]=i; 73 sort(c,c+sz,cmp); 74 for(int i=0;i<sz;i++){ 75 int o=c[i]; 76 if(st[o].link!=-1) 77 cnt[st[o].link]+=cnt[o]; 78 } 79 80 LL ans=0; 81 int u=0,len=0; 82 for(int i=0;i<n2;i++){ 83 int c=s2[i]-'a'; 84 while(u!=-1&&st[u].next[c]==0) 85 u=st[u].link,len=st[u].len; 86 if(u==-1) 87 u=0,len=0; 88 else{ 89 u=st[u].next[c],len++; 90 // ans+=update(u,len); 91 ap[u]++,ans+=(LL)cnt[u]*(len-st[st[u].link].len); 92 } 93 } 94 95 for(int i=0;i<sz;i++){ 96 int o=c[i]; 97 if(st[o].link!=-1) 98 f[st[o].link]+=f[o]+ap[o]; 99 } 100 for(int i=1;i<sz;i++){ 101 ans+=(LL)cnt[i]*f[i]*(st[i].len-st[st[i].link].len); 102 } 103 printf("%lld ",ans); 104 return 0; 105 }