这题曾经用sam打过,现在学sa再来做一遍。
基本思路:计算A所有的后缀和B所有后缀之间的最长公共前缀。
分组之后,假设现在是做B的后缀。前面的串能和当前的B后缀产生的公共前缀必定是从前往后单调递增的,每次与h[i]取min时必定将栈尾一些长的全部取出来,搞成一个短的。
所以就开一个栈,栈里存的是长度,同时存一下它的出现此处cnt。
注意各种细节啊。。
1 #include<cstdio> 2 #include<cstdlib> 3 #include<cstring> 4 #include<iostream> 5 using namespace std; 6 7 typedef long long LL; 8 const int N=2*100010; 9 int K,sl,cl,sa[N],rk[N],Rs[N],wr[N],y[N],h[N]; 10 LL sk[N],cnt[N]; 11 char s[N],c[N]; 12 13 void get_sa(int m) 14 { 15 for(int i=1;i<=cl;i++) rk[i]=c[i]; 16 for(int i=1;i<=m;i++) Rs[i]=0; 17 for(int i=1;i<=cl;i++) Rs[rk[i]]++; 18 for(int i=1;i<=m;i++) Rs[i]+=Rs[i-1]; 19 for(int i=cl;i>=1;i--) sa[Rs[rk[i]]--]=i; 20 21 int ln=1,p=0; 22 while(p<cl) 23 { 24 int k=0; 25 for(int i=cl-ln+1;i<=cl;i++) y[++k]=i; 26 for(int i=1;i<=cl;i++) if(sa[i]>ln) y[++k]=sa[i]-ln; 27 28 for(int i=1;i<=cl;i++) wr[i]=rk[y[i]]; 29 for(int i=1;i<=m;i++) Rs[i]=0; 30 for(int i=1;i<=cl;i++) Rs[wr[i]]++; 31 for(int i=1;i<=m;i++) Rs[i]+=Rs[i-1]; 32 for(int i=cl;i>=1;i--) sa[Rs[wr[i]]--]=y[i]; 33 34 for(int i=1;i<=cl;i++) wr[i]=rk[i]; 35 for(int i=cl+1;i<=cl+ln;i++) wr[i]=0; 36 p=1;rk[sa[1]]=1; 37 for(int i=2;i<=cl;i++) 38 { 39 if(wr[sa[i]]!=wr[sa[i-1]] || wr[sa[i]+ln]!=wr[sa[i-1]+ln]) p++; 40 rk[sa[i]]=p; 41 } 42 ln*=2,m=p; 43 } 44 sa[0]=0,rk[0]=0; 45 } 46 47 void get_h() 48 { 49 int k=0,j; 50 for(int i=1;i<=cl;i++) if(rk[i]!=1) 51 { 52 j=sa[rk[i]-1]; 53 if(k) k--; 54 while(c[i+k]==c[j+k] && i+k<=cl && j+k<=cl) k++; 55 h[rk[i]]=k; 56 } 57 h[1]=0; 58 } 59 60 void init() 61 { 62 int i,tl;cl=0; 63 scanf("%s",s+1); 64 tl=strlen(s+1);sl=tl; 65 for(i=1;i<=sl;i++) c[++cl]=s[i]; 66 scanf("%s",s+1); 67 tl=strlen(s+1); 68 c[++cl]='#'; 69 for(i=1;i<=sl;i++) c[++cl]=s[i]; 70 } 71 72 bool check(int x,int tmp) 73 { 74 if(tmp==0) return (x<=sl) ? 0:1; 75 else return (x<=sl) ? 1:0; 76 } 77 78 LL solve(int tmp) 79 { 80 int tot=0; 81 LL sum=0,ans=0; 82 memset(sk,0,sizeof(sk)); 83 memset(cnt,0,sizeof(cnt)); 84 for(int i=1;i<=cl;i++) 85 { 86 if(h[i]<K) 87 { 88 for(int j=1;j<=tot;j++) cnt[j]=0; 89 tot=0;sum=0; 90 } 91 else 92 { 93 int tcnt=0,tsum=0; 94 while(sk[tot] > h[i]-K+1) 95 { 96 tcnt+=cnt[tot]; 97 tsum+=cnt[tot]*sk[tot]; 98 sk[tot]=0,cnt[tot]=0; 99 tot--; 100 } 101 if(tcnt) 102 { 103 sk[++tot]=h[i]-K+1; 104 cnt[tot]=tcnt; 105 sum=sum-tsum+tcnt*sk[tot]; 106 } 107 if(check(sa[i],tmp)) ans+=sum; 108 } 109 if(!check(sa[i],tmp) && (cl-sa[i]+1>=K)) 110 { 111 sk[++tot]=(cl-sa[i]+1)-K+1; 112 cnt[tot]++; 113 sum+=sk[tot]; 114 } 115 } 116 return ans; 117 } 118 119 int main() 120 { 121 freopen("a.in","r",stdin); 122 freopen("me.out","w",stdout); 123 while(1) 124 { 125 scanf("%d",&K); 126 if(!K) return 0; 127 init(); 128 get_sa(200); 129 get_h(); 130 // for(int i=1;i<=cl;i++) printf("%d ",sa[i]);printf(" "); 131 // for(int i=1;i<=cl;i++) printf("%d ",rk[i]);printf(" "); 132 // for(int i=1;i<=cl;i++) 133 // { 134 // for(int j=sa[i];j<=cl;j++) printf("%c",c[j]);printf(" "); 135 // } 136 printf("%I64d ",solve(0)+solve(1)); 137 } 138 return 0; 139 }