题面
https://www.luogu.com.cn/problem/P1117
分析
其实朴素暴力就有 95pts ...
设 a[i] 为以第 i 位为结尾的 AA 串个数, b[i] 为以第 i 位开头的 AA 串个数
则答案为 $sum_i^{n-1} a[i] imes b[i]$
考虑用后缀数组优化找相邻后缀相同前缀长度串
考虑枚举 A 的长度 l ,则若从 l 开始,每隔 l 设置一个关键点,则一个 AA 串必然刚好经过两个关键点
设 LCP(i,j) 表示 i 与 j 的后缀的公共前缀长度, LCS(i,j) 表示 i 与 j 的前缀的公共后缀长度
则对于关键点 i 与 i+l ,他们被 AA 串经过当且仅当 LCP(i,i+l) + LCS(i,i+l) >= l
然后 AA 串的起点可以落在 i-LCS+1~i-LCP+l-1 , AA 串的终点可以落在 i+l-LCS+l-1~i+l+LCP 上
差分一下即可,公共前后缀长度用线段树或者ST维护一下height即可
代码
#include <iostream> #include <cstdio> #include <cstring> #define lson (x<<1) #define rson ((x<<1)+1) using namespace std; typedef long long ll; const int N=3e4+10; int n,T,sa[2][N],rk[2][N],height[2][N],c[N],x[N],y[N],t[2][4*N],a[N],b[N]; ll ans; char s[N]; void Suffix_Array(int n,char *s,int *sa,int *rk) { int m='z',cnt=0; memset(x,0,sizeof x);memset(y,0,sizeof y); for (int i=0;i<=m;i++) c[i]=0; for (int i=1;i<=n;i++) c[x[i]=s[i]]++; for (int i=1;i<=m;i++) c[i]+=c[i-1]; for (int i=n;i;i--) sa[c[x[i]]--]=i; for (int j=1;j<=n;j<<=1) { cnt=0; for (int i=n-j+1;i<=n;i++) y[++cnt]=i; for (int i=1;i<=n;i++) if (sa[i]>j) y[++cnt]=sa[i]-j; for (int i=0;i<=m;i++) c[i]=0; for (int i=1;i<=n;i++) c[x[i]]++; for (int i=1;i<=m;i++) c[i]+=c[i-1]; for (int i=n;i;i--) sa[c[x[y[i]]]--]=y[i],y[i]=0; swap(x,y);cnt=x[sa[1]]=1; for (int i=2;i<=n;i++) x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+j]==y[sa[i-1]+j])?cnt:++cnt; m=cnt;if (m==n) break; } for (int i=1;i<=n;i++) rk[sa[i]]=i; } void Get_Height(int n,int *sa,int *rk,int *height) { int z=0; for (int i=1;i<=n;i++) { if (rk[i]==1) {height[rk[i]]=0;continue;} if (z) z--; while (i+z<=n&&sa[rk[i]-1]+z<=n&&s[i+z]==s[sa[rk[i]-1]+z]) z++; height[rk[i]]=z; } } void Build(int x,int l,int r,int id) { if (l==r) {t[id][x]=height[id][l];return;} int mid=l+r>>1; Build(lson,l,mid,id);Build(rson,mid+1,r,id); t[id][x]=min(t[id][lson],t[id][rson]); } int Query(int x,int l,int r,int id,int ll,int rr) { if (ll<=l&&r<=rr) return t[id][x]; int mid=l+r>>1,ans=n+1; if (ll<=mid) ans=Query(lson,l,mid,id,ll,rr); if (mid<rr) ans=min(ans,Query(rson,mid+1,r,id,ll,rr)); return ans; } int main() { for (scanf("%d",&T);T;T--) { scanf("%s",s+1); Suffix_Array(n=strlen(s+1),s,sa[0],rk[0]); Get_Height(n,sa[0],rk[0],height[0]);Build(1,1,n,0); for (int i=1;i*2<=n;i++) swap(s[i],s[n-i+1]); Suffix_Array(n,s,sa[1],rk[1]); Get_Height(n,sa[1],rk[1],height[1]);Build(1,1,n,1); memset(a,0,sizeof a);memset(b,0,sizeof b); for (int i=1;i*2<=n;i++) for (int j=i,k,l,p;j+i<=n;j+=i) { k=min(Query(1,1,n,0,min(rk[0][j],rk[0][j+i])+1,max(rk[0][j],rk[0][j+i])),i); l=min(Query(1,1,n,1,min(rk[1][n-j+2],rk[1][n-j-i+2])+1,max(rk[1][n-j+2],rk[1][n-j-i+2])),i-1); if (k+l<i) continue;p=k+l-i+1; a[j+i+k-p]++;a[j+i+k]--; b[j-l]++;b[j-l+p]--; } for (int i=1;i<=n;i++) a[i]+=a[i-1],b[i]+=b[i-1]; ll ans=0; for (int i=1;i<n;i++) ans+=1ll*a[i]*b[i+1]; printf("%lld ",ans); } }