题目大意:
给定两个串,求有多少种方式从两个串中各提取出一个子串并且两个子串相等。
思路:
涉及两个串的子串问题考虑对第一个串建立SAM。
然后用第个二串在SAM上匹配,每到一个点,贡献是(目前的长度-这个状态的父亲的长度)x这个状态RIGHT集合的大小,同时对这个状态的每个祖先也像这样计算贡献即可。
/*=======================================
* Author : ylsoi
* Time : 2019.2.14
* Problem : bzoj4566
* E-mail : ylsoi@foxmail.com
* ====================================*/
#include<bits/stdc++.h>
#define REP(i,a,b) for(int i=a,i##_end_=b;i<=i##_end_;++i)
#define DREP(i,a,b) for(int i=a,i##_end_=b;i>=i##_end_;--i)
#define debug(x) cout<<#x<<"="<<x<<" "
#define fi first
#define se second
#define mk make_pair
#define pb push_back
typedef long long ll;
using namespace std;
void File(){
freopen("bzoj4566.in","r",stdin);
freopen("bzoj4566.out","w",stdout);
}
template<typename T>void read(T &_){
_=0; T f=1; char c=getchar();
for(;!isdigit(c);c=getchar())if(c=='-')f=-1;
for(;isdigit(c);c=getchar())_=(_<<1)+(_<<3)+(c^'0');
_*=f;
}
const int maxn=2e5+10;
int n1,n2;
char s1[maxn],s2[maxn];
int len[maxn<<1],fa[maxn<<1],ch[maxn<<1][26];
int cnt=1,last=1,sz[maxn<<1];
ll sum[maxn<<1],ans;
void insert(int x){
int p=last,np=last=++cnt;
len[np]=len[p]+1;
sz[np]=1;
while(p && !ch[p][x])ch[p][x]=np,p=fa[p];
if(!p)fa[np]=1;
else{
int q=ch[p][x];
if(len[q]==len[p]+1)fa[np]=q;
else{
int nq=++cnt;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
len[nq]=len[p]+1,fa[nq]=fa[q];
fa[q]=fa[np]=nq;
while(p && ch[p][x]==q)ch[p][x]=nq,p=fa[p];
}
}
}
void get_sz(){
int tax[maxn<<1]={0},lis[maxn<<1]={0};
REP(i,1,cnt)++tax[len[i]];
REP(i,1,n1)tax[i]+=tax[i-1];
REP(i,1,cnt)lis[tax[len[i]]--]=i;
DREP(i,cnt,1)sz[fa[lis[i]]]+=sz[lis[i]];
REP(i,1,cnt)sum[i]=1ll*(len[i]-len[fa[i]])*sz[i];
REP(i,1,cnt)sum[lis[i]]+=sum[fa[lis[i]]];
}
void compare(){
int o=1,now=0;
REP(i,1,n2){
int x=s2[i]-'a';
while(o!=1 && !ch[o][x])o=fa[o],now=len[o];
if(ch[o][x]){
o=ch[o][x];
++now;
ans+=1ll*(now-len[fa[o]])*sz[o];
ans+=sum[fa[o]];
}
}
}
int st[maxn],tp;
void dfs(int o){
REP(i,1,tp)printf("%c",st[i]+'a');
printf("
");
REP(i,0,25)if(ch[o][i]){
st[++tp]=i;
dfs(ch[o][i]);
--tp;
}
}
int main(){
File();
scanf("%s%s",s1+1,s2+1);
n1=strlen(s1+1),n2=strlen(s2+1);
REP(i,1,n1){
insert(s1[i]-'a');
}
get_sz();
compare();
printf("%lld
",ans);
return 0;
}