【HAOI2016】 找相同字符
子串之类的问题,容易想到后缀数组。最后的问题是,在 \(A,B\) 中各找一个子串,有多少种情况这两个子串相同。
我们分别对 \(A\),\(B\),\(A+B\) 三个串求其后缀数组和 \(height\) 数组。我们可以求出在各自串中相同的子串数量,然后用容斥原理算出答案即可。
那么如何求各自串中相同的子串数量呢?容易想到单调栈。我们令 \(f_i\) 为 \([sa_i,\cdots,n]\) 子串与其后缀有多少个前缀相等。结合单调栈,容易算出答案。
//Don't act like a loser.
//This code is written by huayucaiji
//You can only use the code for studying or finding mistakes
//Or,you'll be punished by Sakyamuni!!!
//#pragma GCC optimize("Ofast","-funroll-loops","-fdelete-null-pointer-checks")
//#pragma GCC target("ssse3","sse3","sse2","sse","avx2","avx")
#include<bits/stdc++.h>
#define int long long
using namespace std;
int read() {
char ch=getchar();
int f=1,x=0;
while(ch<'0'||ch>'9') {
if(ch=='-')
f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9') {
x=x*10+ch-'0';
ch=getchar();
}
return f*x;
}
const int MAXN=4e5+10;
int sa[MAXN],rnk[MAXN],height[MAXN],cnt[MAXN],f[MAXN];
char a[MAXN>>1],b[MAXN>>1],ab[MAXN];
priority_queue<pair<int,int> > q;
struct strpr {
int x,y,id;
bool operator <(const strpr q)const {
if(x!=q.x) {
return x<q.x;
}
return y<q.y;
}
}p[MAXN];
void combine() {
int n=strlen(a+1);
int m=strlen(b+1);
for(int i=1;i<=n;i++) {
ab[i]=a[i];
}
ab[n+1]='$';
for(int i=1;i<=m;i++) {
ab[i+n+1]=b[i];
}
}
void get_sa(char s[]) {
int n=strlen(s+1);
for(int i=1;i<=n;i++) {
rnk[i]=s[i];
}
for(int l=1;l<n;l<<=1) {
for(int i=1;i<=n;i++) {
p[i].x=rnk[i];
p[i].y=(i+l<=n? rnk[i+l]:0);
p[i].id=i;
}
sort(p+1,p+n+1);
int m=0;
for(int i=1;i<=n;i++) {
if(p[i].x!=p[i-1].x||p[i].y!=p[i-1].y) {
m++;
}
rnk[p[i].id]=m;
}
}
for(int i=1;i<=n;i++) {
sa[rnk[i]]=i;
}
return ;
}
void get_height(char s[]) {
int h=0;
int n=strlen(s+1);
for(int i=1;i<=n;i++) {
if(h) {
h--;
}
if(rnk[i]==1) {
continue;
}
int p=i+h;
int q=sa[rnk[i]-1]+h;
while(p<=n&&q<=n&&s[p]==s[q]) {
p++;
q++;
h++;
}
height[rnk[i]]=h;
}
}
int calc(char s[]) {
int n=strlen(s+1);
fill(sa,sa+n+1,0);
fill(rnk,rnk+n+1,0);
fill(height,height+n+1,0);
fill(f,f+n+1,0);
get_sa(s);
get_height(s);
stack<int> stk;
height[n+1]=0;
int sum=0;
for(int i=2;i<=n;++i)
{
while(stk.size()&&height[stk.top()]>=height[i])
stk.pop();
if(stk.empty())
f[i]=height[i]*(i-1);
else
f[i]=f[stk.top()]+height[i]*(i-stk.top());
stk.push(i);
sum+=f[i];
}
return sum;
}
signed main() {
//freopen(".in","r",stdin);
//freopen(".out","w",stdout);
scanf("%s%s",a+1,b+1);
combine();
printf("%lld\n",calc(ab)-calc(a)-calc(b));
//fclose(stdin);
//fclose(stdout);
return 0;
}
/*
abba
bbba
12
*/