【题目链接】
http://poj.org/problem?id=3415
【题意】
A与B长度至少为k的公共子串个数。
【思路】
基本思想是将AB各个后缀的lcp-k+1的值求和。首先将两个字符串拼接起来中间用未出现的字符隔开,划分height数组,这首先保证了每一组中字符串之间的公共子串至少有k长度,组与组之间互不干扰。
问题变成了求一个组中一个A串与之前B串形成的LCP(lcp-k+1)和一个B串与之前A串形成的LCP,问题是对称的,这里先解决第一个。用一个单调栈,栈中存放两个元素分别height_top与cnt_top,分别表示到i为止的最小height和A串的数目。维护栈中元素的height从顶到底递减:每加入一个元素如果该元素比栈顶元素小则需要将tot中cnt_top个已经累计的height_top全部替换为当前元素的height(lcp是取区间最小值)。
时间复杂度为O(n)。
【代码】
#include<cstdio> #include<cstring> #include<iostream> #define FOR(a,b,c) for(int a=(b);a<=(c);a++) using namespace std; typedef long long LL; const int maxn = 400000 + 10; int s[maxn]; int sa[maxn],c[maxn],t[maxn],t2[maxn]; void build_sa(int m,int n) { int i,*x=t,*y=t2; for(i=0;i<m;i++) c[i]=0; for(i=0;i<n;i++) c[x[i]=s[i]]++; for(i=1;i<m;i++) c[i]+=c[i-1]; for(i=n-1;i>=0;i--) sa[--c[x[i]]]=i; for(int k=1;k<=n;k<<=1) { int p=0; for(i=n-k;i<n;i++) y[p++]=i; for(i=0;i<n;i++) if(sa[i]>=k) y[p++]=sa[i]-k; for(i=0;i<m;i++) c[i]=0; for(i=0;i<n;i++) c[x[y[i]]]++; for(i=0;i<m;i++) c[i]+=c[i-1]; for(i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i]; swap(x,y); p=1; x[sa[0]]=0; for(i=1;i<n;i++) x[sa[i]]=y[sa[i]]==y[sa[i-1]] && y[sa[i]+k]==y[sa[i-1]+k]?p-1:p++; if(p>=n) break; m=p; } } int rank[maxn],height[maxn]; void getHeight(int n) { int i,j,k=0; for(i=0;i<=n;i++) rank[sa[i]]=i; for(i=0;i<n;i++) { if(k) k--; j=sa[rank[i]-1]; while(s[j+k]==s[i+k]) k++; height[rank[i]]=k; } } int n,k; char a[maxn],b[maxn]; int sta[maxn][2]; int main() { while(scanf("%d",&k)==1 && k) { scanf("%s%s",a,b); int lena=strlen(a),lenb=strlen(b); n=0; for(int i=0;i<lena;i++) s[n++]=a[i]; s[n++]=1; for(int i=0;i<lenb;i++) s[n++]=b[i]; s[n]=0; build_sa('z'+1,n+1); getHeight(n); int top=0; LL ans=0,tot=0; for(int i=1;i<=n;i++) { if(height[i]<k) top=0,tot=0; else { int cnt=0; if(sa[i-1]<lena) { cnt++; tot+=height[i]-k+1; } while(top && height[i]<=sta[top-1][0]) { top--; tot+=(height[i]-sta[top][0])*sta[top][1]; cnt+=sta[top][1]; } sta[top][0]=height[i],sta[top++][1]=cnt; if(sa[i]>lena) ans+=tot; } } top=tot=0; for(int i=1;i<=n;i++) { if(height[i]<k) top=0,tot=0; else { int cnt=0; if(sa[i-1]>lena) { cnt++; tot+=height[i]-k+1; } while(top && height[i]<=sta[top-1][0]) { top--; tot+=(height[i]-sta[top][0])*sta[top][1]; cnt+=sta[top][1]; } sta[top][0]=height[i],sta[top++][1]=cnt; if(sa[i]<lena) ans+=tot; } } cout<<ans<<" "; } return 0; }
UPD.16/4/6
【思路】
用字符串A构造SAM,在SAM上匹配第二个字符串B,设当前匹配长度为len,且位于状态p,则当前状态中满足条件长度不小于K的公共子串的字符串个数为
sum = len-max{ K,Min(p) }+1
SAM中一个状态代表的字符串长度为一个连续区间[ Min(s),Max(s) ],Min(s)为最小长度。
这些字符串重复的次数为|right|,即right集的大小,可以递推得到,则当前状态对于答案的贡献为sum*|right|
这时候匹配的是p,还应该统计parent树中p->root的路径上的状态中满足条件的个数。
这里设一个懒标记tag[x],记录节点x需要统计的次数,最后算一遍,每次如果Max(p->fa) >= K则上传标记。
需要注意的是可能出现大写字符 =_=
相比较而言SAM的做法更好想一些。
【代码】
1 #include<set> 2 #include<cmath> 3 #include<queue> 4 #include<vector> 5 #include<cstdio> 6 #include<cstring> 7 #include<iostream> 8 #include<algorithm> 9 #define trav(u,i) for(int i=front[u];i;i=e[i].nxt) 10 #define FOR(a,b,c) for(int a=(b);a<=(c);a++) 11 #define rep(a,b,c) for(int a=(b);a>=(c);a--) 12 using namespace std; 13 14 typedef long long ll; 15 const int N = 2e5+10; 16 17 int K; 18 char A[N],B[N]; 19 20 int get(char c) 21 { 22 if(c>='a'&&c<='z') return c-'a'; 23 else return c-'A'+26; 24 } 25 26 struct SAM 27 { 28 int sz,last; 29 int ch[N][60],fa[N],l[N],c[N],b[N],tag[N]; 30 int siz[N]; ll ans; 31 32 void init() { 33 sz=0; last=++sz; 34 memset(l,0,sizeof(l)); 35 memset(siz,0,sizeof(siz)); 36 memset(fa,0,sizeof(fa)); 37 memset(ch,0,sizeof(ch)); 38 memset(c,0,sizeof(c)); 39 memset(tag,0,sizeof(tag)); 40 } 41 void Add(int c) { 42 int np=++sz,p=last; last=np; 43 l[np]=l[p]+1; siz[np]=1; 44 for(;p&&!ch[p][c];p=fa[p]) ch[p][c]=np; 45 if(!p) fa[np]=1; 46 else { 47 int q=ch[p][c]; 48 if(l[q]==l[p]+1) fa[np]=q; 49 else { 50 int nq=++sz; l[nq]=l[p]+1; 51 memcpy(ch[nq],ch[q],sizeof(ch[q])); 52 fa[nq]=fa[q]; 53 fa[np]=fa[q]=nq; 54 for(;ch[p][c]==q;p=fa[p]) ch[p][c]=nq; 55 } 56 } 57 } 58 void get_right() { 59 FOR(i,1,sz) c[l[i]]++; 60 FOR(i,1,last) c[i]+=c[i-1]; 61 FOR(i,1,sz) b[c[l[i]]--]=i; 62 rep(i,sz,1) siz[fa[b[i]]]+=siz[b[i]]; 63 } 64 ll solve(char* s) { 65 int len=0,p=1; 66 ans=0; 67 for(int i=0;s[i];i++) { 68 int c=get(s[i]); 69 if(ch[p][c]) { 70 len++; p=ch[p][c]; 71 } else { 72 while(p&&!ch[p][c]) p=fa[p]; 73 if(!p) { 74 p=1; len=0; 75 } else { 76 len=l[p]+1; p=ch[p][c]; 77 } 78 } 79 if(len>=K) { 80 ans+=(ll)(len-max(K,l[fa[p]]+1)+1)*siz[p]; 81 if(l[fa[p]]>=K) tag[fa[p]]++; 82 } 83 } 84 rep(j,sz,1) { 85 int i=b[j]; 86 ans+=(ll)tag[i]*(l[i]-max(K,l[fa[i]]+1)+1)*siz[i]; 87 if(l[fa[i]]>=K) tag[fa[i]]+=tag[i]; 88 } 89 return ans; 90 } 91 92 }sam; 93 94 int main() 95 { 96 while(scanf("%d",&K)==1 && K) { 97 scanf("%s%s",A,B); 98 sam.init(); 99 for(int i=0;A[i];i++) 100 sam.Add(get(A[i])); 101 sam.get_right(); 102 printf("%lld ",sam.solve(B)); 103 } 104 return 0; 105 }