【BZOJ3238】[AHOI2013]差异
题面
给定字符串(S),令(T_i)表示以它从第(i)个字符开始的后缀。求
[sum_{1leq i<jleq n}len(T_i)+len(T_j)-2*lcp(T_i,T_j)
]
其中(len(a))表示串(a)的长度,(lcp(a,b))表示串(a,b)的最长公共前缀
题解
把这个式子看作两边分开求:
Part1:
[sum_{1leq i<jleq n}len(T_i)+len(T_j)\
Leftrightarrowsum_{i=1}^{n-1}sum_{j=i+1}^ni+j\
=sum_{i=1}^{n-1}i*(n-i)+frac{(i+1+n)*(n-i)}2
]
其实现在你就可以(O(n))地求了,但是因为我出(kan)于(le)美(ti)观(jie)
发现它其实可以化成这样:
[frac {(n-1)*n*(n+1)}2
]
Part2:
一看到后缀当然是(sa)啦
由后缀数组的性质,排名为分别为(i,j)的后缀,(lcp_{i,j}=minlimits_{k=i+1}^jheight_k)
我们将所有高度数组排成一排,
假设中间的第(i)个数是(l-r)中最小的
则它的贡献就是((i-l+1)*(r-i+1))
我们处理出来对(i)所有的(l,r)是不是就做出来了呢
这不就是一个单调栈的经典应用吗?
而这个题目中因为一些细节问题我的(l)表示小于(i)的第一个,(r)表示小于等于(i)的第一个
详见代码:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
inline int gi() {
register int data = 0, w = 1;
register char ch = 0;
while (!isdigit(ch) && ch != '-') ch = getchar();
if (ch == '-') w = -1, ch = getchar();
while (isdigit(ch)) data = 10 * data + ch - '0', ch = getchar();
return w * data;
}
const int MAX_N = 5e5 + 5;
int N; char a[MAX_N];
int sa[MAX_N], lcp[MAX_N], rnk[MAX_N];
void GetSA() {
#define cmp(i, j, k) (y[i] == y[j] && y[i + k] == y[j + k])
static int x[MAX_N], y[MAX_N], bln[MAX_N];
int M = 122;
for (int i = 1; i <= N; i++) bln[x[i] = a[i]]++;
for (int i = 1; i <= M; i++) bln[i] += bln[i - 1];
for (int i = N; i >= 1; i--) sa[bln[x[i]]--] = i;
for (int k = 1; k <= N; k <<= 1) {
int p = 0;
for (int i = 0; i <= M; i++) y[i] = 0;
for (int i = N - k + 1; i <= N; i++) y[++p] = i;
for (int i = 1; i <= N; i++) if (sa[i] > k) y[++p] = sa[i] - k;
for (int i = 0; i <= M; i++) bln[i] = 0;
for (int i = 1; i <= N; i++) bln[x[y[i]]]++;
for (int i = 1; i <= M; i++) bln[i] += bln[i - 1];
for (int i = N; i >= 1; i--) sa[bln[x[y[i]]]--] = y[i];
swap(x, y); x[sa[1]] = p = 1;
for (int i = 2; i <= N; i++) x[sa[i]] = cmp(sa[i], sa[i - 1], k) ? p : ++p;
if (p >= N) break;
M = p;
}
}
void GetLcp() {
for (int i = 1; i <= N; i++) rnk[sa[i]] = i;
for (int i = 1, j = 0; i <= N; i++) {
if (j) --j;
while (a[i + j] == a[sa[rnk[i] - 1] + j]) ++j;
lcp[rnk[i]] = j;
}
}
typedef long long ll;
int lp[MAX_N], rp[MAX_N], stk[MAX_N], top;
int main () {
scanf("%s", a + 1); N = strlen(a + 1);
GetSA(); GetLcp();
ll ans = 0;
//for (int i = 1; i < N; i++) ans += 1ll * i * (N - i) + 1ll * (i + 1 + N) * (N - i) / 2ll;
ans = 1ll * N * (N + 1) * (N - 1) / 2ll;
stk[0] = 1;
for (int i = 2; i <= N; i++) {
while (top > 0 && lcp[stk[top]] >= lcp[i]) --top;
lp[i] = i - stk[top], stk[++top] = i;
}
top = 0, stk[0] = N + 1;
for (int i = N; i >= 2; i--) {
while (top > 0 && lcp[stk[top]] > lcp[i]) --top;
rp[i] = stk[top] - i, stk[++top] = i;
}
for (int i = 2; i <= N; i++) ans -= 2ll * lcp[i] * lp[i] * rp[i];
printf("%lld
", ans);
return 0;
}