考虑反向操作,去计算有多少组相同的子串,对于一组大小为k的极大相同子串的集合,ans-=k-1。
为了避免重复计算,需要一种有效的,有顺序的记录方案。
比如说,对于每一个相同组,按其起始点所在的位置排序,对于除了第一个串以外的串,均记-1的贡献。
但这种东西是非常难以快速统计的。
但是,可以对于每一个相同组,按其所在的后缀字典序排序,对于除了第一个串以外的串,均记-1的贡献。
下面引用别人的一段话,主要是利用lcp来快速统计了不用长度相同组。
========================================================================
每个子串一定是某个后缀的前缀,那么原问题等价于求所有后缀之间的不相同的前缀的个数。
如果所有的后缀按照 suffix(sa[1]), suffix(sa[2]),suffix(sa[3]), …… ,suffix(sa[n])的顺序计算。
不难发现,对于每一次新加进来的后缀 suffix(sa[k]),它将产生 n-sa[k]+1 个新的前缀。
但是其中有height[k]个是和前面的字符串的前缀是相同的。所以 suffix(sa[k])将“贡献”出 n-sa[k]+1- height[k]个不同的子串。
累加后便是原问题的答案。这个做法的时间复杂度为 O(n)。
最后再强调一下为什么只需要统计height[k],而不需要和之前所有的后缀均计算lcp。
因为,按照刚才我们的分析。把每一个相同组看成一条链,计数只能发生在边上。
如果去和前面的再统计一遍的话,显然是一种错误的越级的行为,造成重复统计。
此外,由于按照字典序排序后,再前面的所有串中,与它相邻的串显然是与它lcp最大的串。
一定可以稳稳地不重不漏的对每一个之前每一个出现过的过的前缀进行统计。
即:按照字典序排序后,如果某个 当前后缀的一个前缀 与前面的某个后缀的一个前缀相同。
那么一定是下图这种情况。
红色代表可能的位置,因为字典序的缘故,与它靠的越紧,相似度越高。
所以 要么贡献已经在之前算过了,要么就会体现在它和与它相邻串的lcp中。
#include<iostream>
#include<cctype>
#include<cstdio>
#include<cstring>
#include<string>
#include<cmath>
#include<ctime>
#include<cstdlib>
#include<algorithm>
#define N 1100000
#define L 1000000
#define eps 1e-7
#define inf 1e9+7
#define ll long long
using namespace std;
inline int read()
{
char ch=0;
int x=0,flag=1;
while(!isdigit(ch)){ch=getchar();if(ch=='-')flag=-1;}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return x*flag;
}
char s[N];
int n,m,c[N],x[N],y[N],sa[N],rank[N],height[N];
int main()
{
n=read();m=122;scanf("%s",s+1);
for(int i=1;i<=n;i++)c[x[i]=s[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i>=1;i--)sa[c[x[i]]--]=i;
for(int k=1;k<=n;k<<=1)
{
int num=0;
for(int i=n-k+1;i<=n;i++)y[++num]=i;
for(int i=1;i<=n;i++)if(sa[i]>k)y[++num]=sa[i]-k;
for(int i=1;i<=m;i++)c[i]=0;
for(int i=1;i<=n;i++)c[x[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i>=1;i--)sa[c[x[y[i]]]--]=y[i],y[i]=0;
swap(x,y);
x[sa[1]]=num=1;
for(int i=2;i<=n;i++)
x[sa[i]]=(y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k])?num:++num;
if(num==n)break;
m=num;
}
ll ans=(ll)n*((ll)n+(ll)1)/(ll)2;
for(int i=1;i<=n;i++)rank[sa[i]]=i;
for(int i=1,k=0;i<=n;i++)
{
if(k)k--;
int j=sa[rank[i]-1];
while(s[i+k]==s[j+k])k++;
height[rank[i]]=k;
ans-=height[rank[i]];
}
printf("%lld",ans);
return 0;
}