题目链接:Common Substrings
题意:给两个串s1,s2,求出长度不小于k的公共子串个数
题解:我们先想一个暴力的解法,先把两个串连到一起中间加一个特殊字符。然后求出sa,和lcp,然后n^2枚举两个子串的开始位置,然后对于每两个子串的公共前缀长度L对
答案的贡献是L-k+1;求和就是答案。但是这个是n^2*log.肯定不行
先把答案分成两部分求,第一部分是对于s1的每一个后缀计算字典序比它小的每一个s2的后缀对答案的贡献,第二部分是对于s2的每一个后缀计算字典序比它小的每一个s1的后缀对答案的贡献.这样不会重复也不会遗漏。前面一样连起来求出sa,和lcp。对于然后我们就需要用单调栈。从底到顶单调递增。为什么要用单调栈呢,因为有这样一个性质
那么对于j和k,不妨设rank[j]<rank[k],则有以下性质:
suffix(j)和suffix(k)的最长公共前缀为height[rank[j]+1],height[rank[j]+2],height[rank[j]+3],…,height[rank[k]]中的最小值。
,这样就可以用一个单调栈来保存一个串之前的所有串的与本串的lcp了。当如果有height值小于了当前栈顶的height值,那么大于它的那些只能按照当前这个小的值来计算
要用两个栈一个维护数量和一个维护的height[]值。计算两次就行了。
//#include<bits/stdc++.h> #include<iostream> #include<string> #include<cstring> #include<cstdio> #include<algorithm> #define pb push_back #define ll long long #define PI 3.14159265 #define ls l,m,rt<<1 #define rs m+1,r,rt<<1|1 #define ws wppp #define eps 1e-7 using namespace std; const int N=2e5+5; const int mod=1e9+7; int s[N]; int sa[N], t[N], t2[N], c[N], n; int ran[N], lcp[N]; const int inf=0x3fffffff; void get_sa(int m) { int i, *x = t, *y = t2; for(i = 0; i < m; i++) c[i] = 0; for(i = 0; i < n; i++) c[x[i] = s[i]]++; for(i = 1; i < m; i++) c[i] += c[i-1]; for(i = n-1; i >= 0; i--) sa[--c[x[i]]] = i; for(int k = 1; k <= n; k <<= 1) { int p = 0; for(i = n-k; i < n; i++) y[p++] = i; for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i] - k; for(i = 0; i < m; i++) c[i] = 0; for(i = 0; i < n; i++) c[x[y[i]]]++; for(i = 0; i< m; i++) c[i] += c[i-1]; for(i = n-1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i]; swap(x, y); p = 1; x[sa[0]] = 0; for(i = 1; i < n; i++) x[sa[i]] = y[sa[i-1]] == y[sa[i]] && y[sa[i-1]+k] == y[sa[i]+k] ? p-1 : p++; if(p >= n) break; m = p; } int k = 0; for(i = 0; i < n; i++) ran[sa[i]] = i; for(i = 0; i < n; i++) { if(k) k--; int j = sa[ran[i]-1]; while(i+k<n&&j+k<n&&s[i+k] == s[j+k]) k++; lcp[ran[i]] = k; } } int m=0; ll st[N][2]; char s1[N],s2[N]; int main() { while(scanf("%d",&m)&&m) { scanf("%s",s1); scanf("%s",s2); int l1=strlen(s1); int l2=strlen(s2); for(int i=0;i<l1;i++) { s[i]=s1[i]+1; } s[l1]=1; for(int i=0;i<l2;i++) { s[l1+i+1]=s2[i]+1; } n=l1+l2+1; s[n]=0; n++; get_sa(300); // cout<<n<<endl; ll ans=0,sum=0; int tp=0; for(int i=1;i<n;i++) { if(lcp[i]<m)tp=0,sum=0; else { int num=0; while(tp&&lcp[i]<st[tp-1][0]) { sum+=(lcp[i]-st[tp-1][0])*(st[tp-1][1]); num+=st[tp-1][1]; tp--; } st[tp][0]=lcp[i]; if(sa[i-1]>l1) { sum+=lcp[i]-m+1; st[tp++][1]=num+1; } else st[tp++][1]=num; if(sa[i]<l1) { ans+=sum; } } } for(int i=1;i<n;i++) { if(lcp[i]<m)tp=0,sum=0; else { int num=0; while(tp&&lcp[i]<st[tp-1][0]) { sum+=(lcp[i]-st[tp-1][0])*(st[tp-1][1]); num+=st[tp-1][1]; tp--; } st[tp][0]=lcp[i]; if(sa[i-1]<l1) { sum+=lcp[i]-m+1; st[tp++][1]=num+1; } else st[tp++][1]=num; if(sa[i]>l1) { ans+=sum; } } } printf("%lld ",ans); } }