题目要求求出两个两个字符串中相同子串的方案数,那么我们将其拼接起来,去求出拼接后的字符串中含有相同子串的数量。
当然这样做会求出同一个字符串中相同子串的数量,所以我们还需要如法炮制分别求出两个字符串中的答案,然后用总贡献减去他们。
那么问题就变成了如何求出一个字符串中相同子串的数量。
实际上这就是求任意两个后缀x,y的lcp(x,y)和,利用后缀数组,可以知道任意两个lcp(x,y)为min{Height[i]} x <= i <= y
所以说问题在利用后缀数组求出Height数组之后转化为了如何求Height所有子区间内最小值的和。
事实上这是一个经典的利用DP做的问题,求解一个序列中任意子区间的最小值的和,设置L[i]表示最左端的大于a[i]的下标,R[i]表示最右端的大于等于a[i]的下标,注意需要左闭右开,然后在求出来之后对于每个位置,对答案的贡献就是 a[i] * (i - L[i] + 1) * (R[i] - i + 1)
注意事项:在拼接字符串的过程中,需要在两字符串中间增加一个'z' + 1的字符,防止出现取到的相同子串跨越两个被拼接的字符串
#include <map> #include <set> #include <ctime> #include <cmath> #include <queue> #include <stack> #include <vector> #include <string> #include <bitset> #include <cstdio> #include <cstdlib> #include <cstring> #include <sstream> #include <iostream> #include <algorithm> #include <functional> using namespace std; #define For(i, x, y) for(int i=x;i<=y;i++) #define _For(i, x, y) for(int i=x;i>=y;i--) #define Mem(f, x) memset(f,x,sizeof(f)) #define Sca(x) scanf("%d", &x) #define Sca2(x,y) scanf("%d%d",&x,&y) #define Sca3(x,y,z) scanf("%d%d%d",&x,&y,&z) #define Scl(x) scanf("%lld",&x) #define Pri(x) printf("%d ", x) #define Prl(x) printf("%lld ",x) #define CLR(u) for(int i=0;i<=N;i++)u[i].clear(); #define LL long long #define ULL unsigned long long #define mp make_pair #define PII pair<int,int> #define PIL pair<int,long long> #define PLL pair<long long,long long> #define pb push_back #define fi first #define se second typedef vector<int> VI; int read(){int x = 0,f = 1;char c = getchar();while (c<'0' || c>'9'){if (c == '-') f = -1;c = getchar();} while (c >= '0'&&c <= '9'){x = x * 10 + c - '0';c = getchar();}return x*f;} const double PI = acos(-1.0); const double eps = 1e-9; const int maxn = 4e5 + 10; const int INF = 0x3f3f3f3f; const int mod = 1e9 + 7; const int SP = 20; int M,K,l1,l2,l; char str[maxn]; int sa[maxn],rak[maxn],tex[maxn],tp[maxn],Height[maxn]; void Qsort(int N){ for(int i = 0; i <= M ; i ++) tex[i] = 0; for(int i = 1; i <= N ; i ++) tex[rak[i]]++; for(int i = 1; i <= M ; i ++) tex[i] += tex[i - 1]; for(int i = N; i >= 1 ; i --) sa[tex[rak[tp[i]]]--] = tp[i]; } void SA(char *str,int N){ for(int i = 1; i <= N ; i ++) rak[i] = str[i] - '0' + 1,tp[i] = i; Qsort(N); for(int w = 1,p = 0; p < N; M = p, w <<= 1){ p = 0; for(int i = 1; i <= w; i ++) tp[++p] = N - w + i; for(int i = 1; i <= N ; i ++) if(sa[i] > w) tp[++p] = sa[i] - w; Qsort(N); swap(rak,tp); rak[sa[1]] = p = 1; for(int i = 2; i <= N ; i ++){ rak[sa[i]] = (tp[sa[i]] == tp[sa[i - 1]] && tp[sa[i] + w] == tp[sa[i - 1] + w])?p:++p; } } } void GetHeight(char *str,int N){ int j,k = 0; for(int i = 1; i <= N ; i ++){ if(k) k--; int j = sa[rak[i] - 1]; while(i + k <= N && j + k <= N && str[i + k] == str[j + k]) k++; Height[rak[i]] = k; } } LL L[maxn],R[maxn]; LL work(int* a,int n){ for(int i = 1; i <= n ; i ++){ L[i] = i; while(L[i] != 1 && a[i] < a[L[i] - 1]) L[i] = L[L[i] - 1]; } for(int i = n; i >= 1; i --){ R[i] = i; while(R[i] != n && a[i] <= a[R[i] + 1]) R[i] = R[R[i] + 1]; } LL ans = 0; for(int i = 1; i <= n ; i ++){ ans += a[i] * ((R[i] - i + 1) * (i - L[i] + 1)); } return ans; } int main(){ scanf("%s",str + 1); l1 = strlen(str + 1); str[l1 + 1] = 'z' + 1; scanf("%s",str + 2 + l1); l2 = strlen(str + 2 + l1); l = l1 + l2 + 1; M = 122; SA(str,l); LL ans = 0; GetHeight(str,l); ans += work(Height + 1,l - 1); M = 122; SA(str,l1); GetHeight(str,l1); ans -= work(Height + 1,l1 - 1); M = 122; SA(str + l1 + 1,l2); GetHeight(str + l1 + 1,l2); ans -= work(Height + 1,l2 - 1); Prl(ans); return 0; }