(color{#0066ff}{ 题目描述 })
给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。
(color{#0066ff}{输入格式})
两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母
(color{#0066ff}{输出格式})
输出一个整数表示答案
(color{#0066ff}{输入样例})
aabb
bbaa
(color{#0066ff}{输出样例})
10
(color{#0066ff}{数据范围与提示})
none
(color{#0066ff}{ 题解 })
考虑把两个串拼起来,中间隔一个无关字符
我们每次找到一个合法的LCP,显然会产生LCP所有字串的贡献,但是这样会重复
我们定住一个端点,也就是让它产生LCP长度的贡献,这样在不同后缀中一端不同,相同后缀中另一端不同
怎么统计呢?
考虑单步容斥,用拼好的串的贡献-两个串内部贡献
#include<bits/stdc++.h>
#define LL long long
LL in() {
char ch; LL x = 0, f = 1;
while(!isdigit(ch = getchar()))(ch == '-') && (f = -f);
for(x = ch ^ 48; isdigit(ch = getchar()); x = (x << 1) + (x << 3) + (ch ^ 48));
return x * f;
}
const int inf = 0x7fffffff;
const int maxn = 4e5 + 5;
struct SA {
protected:
int x[maxn], y[maxn], rk[maxn], sa[maxn], c[maxn], st[maxn];
int top, n, m;
LL l[maxn], r[maxn], h[maxn];
public:
void operator () (char *s, int len) {
n = len, m = 122;
for(int i = 1; i <= n; i++) c[x[i] = s[i]]++;
for(int i = 1; i <= m; i++) c[i] += c[i - 1];
for(int i = n; i >= 1; i--) sa[c[x[i]]--] = i;
for(int k = 1; k <= n; k <<= 1) {
int num = 0;
for(int i = n - k + 1; i <= n; i++) y[++num] = i;
for(int i = 1; i <= n; i++) if(sa[i] > k) y[++num] = sa[i] - k;
for(int i = 1; i <= m; i++) c[i] = 0;
for(int i = 1; i <= n; i++) c[x[i]]++;
for(int i = 1; i <= m; i++) c[i] += c[i - 1];
for(int i = n; i >= 1; i--) sa[c[x[y[i]]]--] = y[i], y[i] = 0;
std::swap(x, y);
x[sa[1]] = 1, num = 1;
for(int i = 2; i <= n; i++)
x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k])? num : ++num;
if(n == num) break;
m = num;
}
for(int i = 1; i <= n; i++) rk[i] = x[i];
int H = 0;
for(int i = 1; i <= n; i++) {
if(rk[i] == 1) continue;
if(H) H--;
int j = sa[rk[i] - 1];
while(i + H <= n && j + H <= n && s[j + H] == s[i + H]) H++;
h[rk[i]] = H;
}
}
LL getans() {
LL ans = 0;
h[0] = h[n + 1] = -inf;
st[top = 1] = 0;
for(int i = 1; i <= n; i++) {
while(h[i] <= h[st[top]]) top--;
l[i] = st[top];
st[++top] = i;
}
st[top = 1] = n + 1;
for(int i = n; i >= 1; i--) {
while(h[i] < h[st[top]]) top--;
r[i] = st[top];
st[++top] = i;
}
for(LL i = 1; i <= n; i++) ans += (r[i] - i) * (i - l[i]) * h[i];
return ans;
}
}a, b, c;
char s[maxn], t[maxn];
int main() {
scanf("%s", s + 1);
scanf("%s", t + 1);
int lens = strlen(s + 1);
int lent = strlen(t + 1);
a(s, lens);
b(t, lent);
s[++lens] = '#';
for(int i = 1; i <= lent; i++) s[++lens] = t[i];
c(s, lens);
printf("%lld
", c.getans() - a.getans() - b.getans());
return 0;
}