POJ_3415
不妨设首字符在第一个字符串里的后缀为A类后缀,首字符在第二个字符串里面的后缀为B类后缀。首先要将两个字符串合并为一个字符串并用分隔符隔开,然后处理出height数组。对于任意一个A类后缀i,和任意一个B类后缀j,假设其公共前缀的长度为k,这两个后缀所能贡献出的S集合里的元素的数目就是k-K+1,当然前提是k>=K。
但是按这样的思路,即便求k时利用height数组的性质及RMQ问题的算法,最后也只能做到O(n^2)复杂度。于是必须优化计算k的过程的时间复杂度。
接下来先要对前面的思路做一个等价的转化,我们选择顺序遍历两次height数组,第一次遍历的时候如果遇到B类后缀,那么就计算一下这个后缀与前面的所有A类后缀对S集合里的元素数目的贡献,第二次遍历的时候如果遇到A类后缀,那么就计算一下这个后缀与前面的所有B类后缀对S集合里的元素数目的贡献。这两部分的和就是最后结果。
接下来的问题就是怎么才能每次都快速地计算出贡献值了,也就是说我们需要维护一个变量t,表示假如遇到当前这个后缀后需要计算一次贡献,那么这个贡献值就是t,在扫描height[]的过程中也就需要不断地更新t。
首先来说,如果将所有相邻的且值不小于K的height[]看成一组的话,那么这个组内的后缀对S的贡献只能是组内的A类与B类后缀的公共前缀产生的,于是我们只需要在一个组的范围内维护一个t即可。
接下来的思路就交由大家思考吧,我也是在看了“维护一个单调栈”的提示后才想到的后面的思路。
#include<stdio.h>
#include<string.h>
#define MAXD 200010
char b[MAXD];
int N, M, K, sa[MAXD], rank[MAXD], height[MAXD], r[MAXD], wa[MAXD], wb[MAXD], ws[MAXD], wv[MAXD];
int s[MAXD], num[MAXD];
void init()
{
int i, j, k = 0;
scanf("%s", b);
for(i = 0; b[i]; i ++, k ++)
r[k] = b[i];
N = i;
r[k ++] = '$';
scanf("%s", b);
for(i = 0; b[i]; i ++, k ++)
r[k] = b[i];
r[M = k] = 0;
}
int cmp(int *p, int x, int y, int l)
{
return p[x] == p[y] && p[x + l] == p[y + l];
}
void da(int n, int m)
{
int i, j, p, *x = wa, *y = wb, *t;
for(i = 0; i < m; i ++)
ws[i] = 0;
for(i = 0; i < n; i ++)
++ ws[x[i] = r[i]];
for(i = 1; i < m; i ++)
ws[i] += ws[i - 1];
for(i = n - 1; i >= 0; i --)
sa[-- ws[x[i]]] = i;
for(j = p = 1; p < n; j *= 2, m = p)
{
for(p = 0, i = n - j; i < n; i ++)
y[p ++] = i;
for(i = 0; i < n; i ++)
if(sa[i] >= j)
y[p ++] = sa[i] - j;
for(i = 0; i < n; i ++)
wv[i] = x[y[i]];
for(i = 0; i < m; i ++)
ws[i] = 0;
for(i = 0; i < n; i ++)
++ ws[wv[i]];
for(i = 1; i < m; i ++)
ws[i] += ws[i - 1];
for(i = n - 1; i >= 0; i --)
sa[-- ws[wv[i]]] = y[i];
for(t = x, x = y, y = t, x[sa[0]] = 0, i = p = 1; i < n; i ++)
x[sa[i]] = cmp(y, sa[i - 1], sa[i], j) ? p - 1 : p ++;
}
}
void calheight(int n)
{
int i, j, k = 0;
for(i = 1; i <= n; i ++)
rank[sa[i]] = i;
for(i = 0; i < n; height[rank[i ++]] = k)
for(k ? -- k : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; k ++);
}
void solve()
{
int i, j, k, top = 0, n;
long long int ans, t;
da(M + 1, 128);
calheight(M);
ans = 0;
for(i = 1; i <= M; i ++)
{
if(height[i] < K)
{
t = 0;
top = -1;
}
else
{
if(sa[i - 1] < N)
{
n = 1;
t += height[i] - K + 1;
}
else
n = 0;
while(top >= 0 && height[i] <= s[top])
{
if(num[top])
{
t -= (long long int)num[top] * (s[top] - K + 1);
t += (long long int)num[top] * (height[i] - K + 1);
n += num[top];
}
-- top;
}
s[++ top] = height[i];
num[top] = n;
if(sa[i] > N)
ans += t;
}
}
for(i = 1; i <= M; i ++)
{
if(height[i] < K)
{
t = 0;
top = -1;
}
else
{
if(sa[i - 1] > N)
{
n = 1;
t += height[i] - K + 1;
}
else
n = 0;
while(top >= 0 && height[i] <= s[top])
{
if(num[top])
{
t -= (long long int)num[top] * (s[top] - K + 1);
t += (long long int)num[top] * (height[i] - K + 1);
n += num[top];
}
-- top;
}
s[++ top] = height[i];
num[top] = n;
if(sa[i] < N)
ans += t;
}
}
printf("%lld\n", ans);
}
int main()
{
for(;;)
{
scanf("%d", &K);
if(!K)
break;
init();
solve();
}
return 0;
}