poj3415
题意
给定两个字符串,给出长度 (m) ,问这两个字符串有多少对长度大于等于 (m) 且完全相同的子串。
分析
首先连接两个字符串 A B,中间用一个特殊符号分割开。
按照 (sa) 的顺序(即枚举 (height) 值),进行分组,那么有公共前缀长大于等于 (m) 的都分到了一组,对于某一组,后缀串可能来自于 A 也可能来自于 B,那么对于 A 找前面的 B 串,对于 B 找前面的 A 串,如果某两个后缀串的公共前缀长为 (l(l geqslant m)),那么显然会有 (l - m + 1) 对子串。
注意到这个性质: 对于两个后缀串 j 和 k,设 (rnk[j] < rnk[k]) ,LCP长度为 (height[rnk[j]+1], height[rnk[j]+2], ... , height[rnk[k]]) 中的最小值。
维护一个单调递增的栈(保证栈顶最大)可以用一个二维数组表示((q[][2])),一个是栈,一个是某个数的个数。
举个例子,如果连续的 (height) 值为 (2 3 4) ,(m = 2),前三个为 A 串,那么 (2 3 4) 全部入栈,且计算对答案的贡献 (sum)(不是直接加到答案上),即 ((2-2+1) + (3-2+1) + (4-2+1)) ,到 B 串时,答案就加上了这个值,但是如果后面还有一个 B 串且 (height) 为 (3),那么就要弹栈,且减去 (sum) 值多的那部分(前面多算了),栈里 (4) 的数量为 (1),所以 (sum = sum - (4 - 3) * 1) ,且栈里 (3) 的数量变为了 (2) ( (4) 对应的 A 串对于后面串提供的贡献减小了(注意前面的性质),所以(4) 变为了 (3) ),答案加上 (sum)。
code
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 2e5 + 10;
const int INF = 1e9;
char s[MAXN];
int sa[MAXN], t[MAXN], t2[MAXN], c[MAXN], n; // n 为 字符串长度 + 1,即最后一位为数字 0
int rnk[MAXN], height[MAXN];
// 构造字符串 s 的后缀数组。每个字符值必须为 0 ~ m-1
void build_sa(int m) {
int i, *x = t, *y = t2;
for(i = 0; i < m; i++) c[i] = 0;
for(i = 0; i < n; i++) c[x[i] = s[i]]++;
for(i = 1; i < m; i++) c[i] += c[i - 1];
for(i = n - 1; i >= 0; i--) sa[--c[x[i]]] = i;
for(int k = 1; k <= n; k <<= 1) {
int p = 0;
for(i = n - k; i < n; i++) y[p++] = i;
for(i = 0; i < n; i++) if(sa[i] >= k) y[p++] = sa[i] - k;
for(i = 0; i < m; i++) c[i] = 0;
for(i = 0; i < n; i++) c[x[y[i]]]++;
for(i = 0; i < m; i++) c[i] += c[i - 1];
for(i = n - 1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];
swap(x, y);
p = 1;
x[sa[0]] = 0;
for(i = 1; i < n; i++)
x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++;
if(p >= n) break;
m = p;
}
}
void getHeight() {
int i, j, k = 0;
for(i = 0; i < n; i++) rnk[sa[i]] = i;
for(i = 0; i < n - 1; i++) {
if(k) k--;
j = sa[rnk[i] - 1];
while(s[i + k] == s[j + k]) k++;
height[rnk[i]] = k;
}
}
char s2[MAXN];
int q[MAXN][2];
int main() {
int m;
while(~scanf("%d", &m) && m) {
scanf("%s%s", s, s2); // A 、B串
int l = strlen(s), l2 = strlen(s2);
s[l++] = '#';
for(int i = 0; i < l2; i++) s[i + l] = s2[i];
s[l + l2] = 0;
n = l + l2 + 1;
build_sa(128);
getHeight();
ll ans = 0, sum = 0;
int top = 0;
// 在 B 串前找 A
for(int i = 2; i < n; i++) {
int cnt = 0;
if(height[i] < m) {
top = 0; sum = 0;
continue;
}
if(sa[i - 1] < l) {
cnt++;
sum += height[i] - m + 1;
}
while(top && q[top - 1][0] >= height[i]) {
top--;
sum -= (q[top][0] - height[i]) * q[top][1];
cnt += q[top][1];
}
q[top][0] = height[i]; q[top++][1] = cnt;
if(sa[i] >= l) ans += sum;
}
// 在 A 串前找 B
sum = 0; top = 0;
for(int i = 2; i < n; i++) {
int cnt = 0;
if(height[i] < m) {
top = 0; sum = 0;
continue;
}
if(sa[i - 1] >= l) {
cnt++;
sum += height[i] - m + 1;
}
while(top && q[top - 1][0] >= height[i]) {
top--;
sum -= (q[top][0] - height[i]) * q[top][1];
cnt += q[top][1];
}
q[top][0] = height[i]; q[top++][1] = cnt;
if(sa[i] < l) ans += sum;
}
printf("%lld
", ans);
}
return 0;
}