题意
这题正解还是有点难想的。
显然满足条件的字符串就是两个(AA)样子的串一前一后。
设(a_i)表示以(i)为开头的(AA)串长度,(a_i)表示以(i)为结尾的(AA)串长度,那么答案显然为:
(sumlimits_{i=1}^{n-1}a_{i+1}*b_i)
于是考虑怎么求这个,我们考虑枚举(AA)的长度(len),对原串每隔(len)设一个关键点,因为长为(len)的(AA)串必定过且只过两个关键点,因此对于每对相邻的关键点,我们求出它们随对应的(AA)串。
求出(lcp)表示([c_{i},n])和([c_{i+1},n])((c_i)表示第(i)个关键点)的最长公共前缀,(lcs)表示([1,c_{i-1}-1])和([1,c_{i+1}-1])的最长公共后缀。
当(lcs+lcp<len)时,如下图:
我们发现并不会有(AA)串过它们。
当(lcp+lcsgeqslant len)时:
我们设(t=lcp-lcs-len+1)
我们发现在前面的(t)长度的点都可以作为一个(AA)串的开头,后面(t)长度的点都可以作为一个(AA)串的结尾。
于是我们要区间加(1),这个差分就好了。
code:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=30010;
int T,n;
int a[maxn],b[maxn],c[maxn],lg[maxn];
ll ans;
char s[maxn];
struct SA
{
int n,num;
int sa[maxn],rk[maxn],oldrk[maxn],id[maxn],tmpid[maxn],cnt[maxn];
int height[maxn][20];
char s[maxn];
inline void clear()
{
memset(sa,0,sizeof(sa));
memset(rk,0,sizeof(rk));
memset(height,0x3f,sizeof(height));
memset(s,0,sizeof(s));//一定要清空字符串。
}
inline bool cmp(int x,int y,int k){return oldrk[x]==oldrk[y]&&oldrk[x+k]==oldrk[y+k];}
inline void build()
{
num=300;
memset(cnt,0,sizeof(cnt));
for(int i=1;i<=n;i++)cnt[rk[i]=s[i]]++;
for(int i=1;i<=num;i++)cnt[i]+=cnt[i-1];
for(int i=n;i;i--)sa[cnt[rk[i]]--]=i;
for(int t=1;t<=n;t<<=1)
{
int tot=0;
for(int i=n-t+1;i<=n;i++)id[++tot]=i;
for(int i=1;i<=n;i++)if(sa[i]>t)id[++tot]=sa[i]-t;
tot=0;
memset(cnt,0,sizeof(cnt));
for(int i=1;i<=n;i++)cnt[tmpid[i]=rk[id[i]]]++;
for(int i=1;i<=num;i++)cnt[i]+=cnt[i-1];
for(int i=n;i;i--)sa[cnt[tmpid[i]]--]=id[i];
memcpy(oldrk,rk,sizeof(rk));
for(int i=1;i<=n;i++)rk[sa[i]]=cmp(sa[i-1],sa[i],t)?tot:++tot;
num=tot;
if(num>=n)break;
}
for(int i=1,j=0;i<=n;i++)
{
if(j)j--;
while(s[i+j]==s[sa[rk[i]-1]+j])j++;
height[rk[i]][0]=j;
}
for(int j=1;j<=18;j++)
for(int i=1;i+(1<<j)-1<=n;i++)
height[i][j]=min(height[i][j-1],height[i+(1<<(j-1))][j-1]);
}
inline int query(int x,int y)
{
x=rk[x],y=rk[y];
if(x>y)swap(x,y);x++;
int t=lg[y-x+1];
return min(height[x][t],height[y-(1<<t)+1][t]);
}
}Sa[2];
inline void init()
{
memset(a,0,sizeof(a));
memset(b,0,sizeof(b));
Sa[0].clear(),Sa[1].clear();
ans=0;
}
inline void solve()
{
scanf("%s",s+1);n=strlen(s+1);
Sa[0].n=Sa[1].n=n;
for(int i=1;i<=n;i++)Sa[0].s[i]=Sa[1].s[n-i+1]=s[i];
Sa[0].build(),Sa[1].build();
for(int len=1;len<=n/2;len++)
{
int tot=0;
for(int i=len;i<=n;i+=len)c[++tot]=i;
for(int i=1;i<tot;i++)
{
int lcp=min(Sa[0].query(c[i],c[i+1]),len),lcs=min(Sa[1].query(n-c[i]+2,n-c[i+1]+2),len-1);
if(lcp+lcs<len)continue;
int t=lcp+lcs-len+1;
a[c[i]-lcs]++,a[c[i]-lcs+t]--;
b[c[i+1]+lcp-t]++,b[c[i+1]+lcp]--;
}
}
for(int i=1;i<=n;i++)a[i]+=a[i-1],b[i]+=b[i-1];
for(int i=1;i<n;i++)ans+=1ll*a[i+1]*b[i];
printf("%lld
",ans);
}
int main()
{
//freopen("test.in","r",stdin);
//freopen("test.out","w",stdout);
lg[0]=-1;
for(int i=1;i<=30000;i++)lg[i]=lg[i>>1]+1;
scanf("%d",&T);
while(T--)
{
init();
solve();
}
return 0;
}