Description
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
Solution
题中所求即为A串与B串任意一组height之和,所以暴力可以做到$O(n^2)$
考虑将答案挂在这两组后缀中sa值较大的那个上,即每个后缀只考虑出现在它前的所有后缀贡献的答案
由于lcp长度为在这两排名间height的最小值,可以按排名枚举后缀,单调栈维护答案
可以做到$O(n log n)$或$O(n)$
#include<iostream> #include<cstring> #include<cstdio> #include<stack> using namespace std; int n,m=127,len1,len2,buc[400010],x[400010],y[400010],sa[400010],rk[400010],height[400010],sum[400010]; long long ans; char s[400010],s2[200005]; struct Node { int id; long long val; }; stack<Node>sta; inline int read() { int w=0,f=1; char ch=0; while(ch<'0'||ch>'9'){if(ch=='-') f=-1; ch=getchar();} while(ch>='0'&&ch<='9')w=(w<<1)+(w<<3)+ch-'0',ch=getchar(); return w*f; } void getsa() { for(int i=1;i<=n;i++) ++buc[x[i]=s[i]]; for(int i=2;i<=m;i++) buc[i]+=buc[i-1]; for(int i=n;i;i--) sa[buc[x[i]]--]=i; for(int k=1;k<=n;k<<=1) { int num=0; for(int i=n-k+1;i<=n;i++) y[++num]=i; for(int i=1;i<=n;i++) if(sa[i]>k) y[++num]=sa[i]-k; for(int i=1;i<=m;i++) buc[i]=0; for(int i=1;i<=n;i++) ++buc[x[i]]; for(int i=2;i<=m;i++) buc[i]+=buc[i-1]; for(int i=n;i;i--) sa[buc[x[y[i]]]--]=y[i],y[i]=0; swap(x,y),x[sa[1]]=1,num=1; for(int i=2;i<=n;i++) x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?num:++num; if(num==n) break; m=num; } } void getheight() { int k=0; for(int i=1;i<=n;i++) rk[sa[i]]=i; for(int i=1;i<=n;i++) { if(rk[i]==1) continue; if(k) --k; int j=sa[rk[i]-1]; while(i+k<=n&&j+k<=n&&s[i+k]==s[j+k]) ++k; height[rk[i]]=k; } } int main() { scanf("%s",s+1),len1=strlen(s+1),s[len1+1]='z'+1,scanf("%s",s2+1),len2=strlen(s2+1); for(int i=1;i<=len2;i++) s[len1+1+i]=s2[i]; n=strlen(s+1),getsa(),getheight(),sta.push((Node){1,0}); for(int i=1;i<=n;i++) sum[i]=sum[i-1]+(sa[i]<=len1); for(int i=2;i<=n;i++) { while(sta.size()&&height[sta.top().id]>height[i]) sta.pop(); Node temp=sta.top(); sta.push((Node){i,(sum[i-1]-sum[temp.id-1])*height[i]+temp.val}); if(sa[i]>len1+1) ans+=sta.top().val; } while(sta.size()) sta.pop(); sta.push((Node){1,0}); for(int i=1;i<=n;i++) sum[i]=sum[i-1]+(sa[i]>len1+1); for(int i=2;i<=n;i++) { while(sta.size()&&height[sta.top().id]>height[i]) sta.pop(); Node temp=sta.top(); sta.push((Node){i,(sum[i-1]-sum[temp.id-1])*height[i]+temp.val}); if(sa[i]<=len1) ans+=sta.top().val; } printf("%lld ",ans); return 0; }