借用罗穗骞论文中的讲解:
计算A 的所有后缀和B 的所有后缀之间的最长公共前缀的长度,把最长公共前缀长度不小于k 的部分全部加起来。先将两个字符串连起来,中间用一个没有出现过的字符隔开。按height 值分组后,接下来的工作便是快速的统计每组中后缀之间的最长公共前缀之和。扫描一遍,每遇到一个B 的后缀就统计与前面的A 的后缀能产生多少个长度不小于k 的公共子串,这里A 的后缀需要用一个单调的栈来高效的维护。然后对A 也这样做一次。
#include<cstdio> #include<iostream> #include<algorithm> #include<cstring> #include<cstdlib> using namespace std; const int N = 210008; typedef long long LL; int val[N], sum[N], wa[N], wb[N]; int sa[N], rk[N], height[N]; char a[N], b[N]; inline bool cmp(int str[], int a, int b, int l){ return str[a] == str[b] && str[a + l] == str[b + l]; } void da(char str[], int n, int m){ int *x = wa, *y = wb; memset(sum, 0, sizeof(sum)); for(int i = 0; i < n; i++){ sum[x[i] = str[i]]++; } for(int i = 1; i < m; i++){ sum[i] += sum[i - 1]; } for(int i = n - 1; i >= 0; i--){ sa[--sum[ x[i]] ] = i; } for(int j = 1, p = 1; p < n; j *= 2, m = p){ p = 0; for(int i = n - j; i < n; i++){ y[p++] = i; } for(int i = 0; i < n; i++){ if(sa[i] >= j){ y[p++] = sa[i] - j; } } for(int i = 0; i < n; i++){ val[i] = x[ y[i] ]; } memset(sum , 0, sizeof(sum)); for(int i = 0; i < n; i++){ sum[val[i]]++; } for(int i = 1; i < m; i++){ sum[i] += sum[i - 1]; } for(int i = n - 1; i >= 0; i--){ sa[--sum[ val[i] ]] = y[i]; } swap(x, y); x[sa[0]] = 0; p = 1; for(int i = 1 ; i < n; i++){ x[sa[i]] = cmp(y, sa[i - 1], sa[i], j)? p - 1:p++; } } } void getHeight(char str[], int n){ for(int i = 1; i <= n; i++){ rk[ sa[i] ] = i; } int k = 0; for(int i = 0; i < n; height[rk[i++]] = k){ if(k) k--; int j = sa[rk[i] - 1]; while(str[i + k] == str[j + k]){ k++; } } } struct node{ int h; LL cnt; }stk[N]; int main(){ int k; while(~scanf("%d", &k) && k){ scanf("%s %s", a, b); int n = strlen(a); int m = strlen(b); int len = n + m + 1; a[n] = 125; for(int i = n + 1, j = 0; j < m; i++, j++){ a[i] = b[j]; } a[len] = 0; da(a, len + 1, 150); getHeight(a, len); LL sum = 0; int top = 0; LL tot = 0; for(int i = 1; i <= len ; i++){ int cnt = 0; if(height[i] < k){ top = 0; tot = 0; }else{ if(sa[i - 1] < n){ cnt++; tot += height[i] - k + 1; } while(top > 0 && stk[top - 1].h > height[i]){ top--; cnt += stk[top].cnt; tot -= (stk[top].h - height[i]) * stk[top].cnt; } stk[top].h = height[i]; stk[top].cnt = cnt; top++; if(sa[i] > n){ sum += tot; } } } top = 0; tot = 0; for(int i = 1; i <= len ; i++){ int cnt = 0; if(height[i] < k){ top = 0; tot = 0; }else{ if(sa[i - 1] > n){ cnt++; tot += height[i] - k + 1; } while(top > 0 && stk[top - 1].h > height[i]){ top--; cnt += stk[top].cnt; tot -= (stk[top].h - height[i]) * stk[top].cnt; } stk[top].h = height[i]; stk[top].cnt = cnt; top++; if(sa[i] < n){ sum += tot; } } } printf("%I64d ", sum); } return 0; }