可以发现,对于原串的每个长度>1的子串而言,将其除了最后一个字符之外反向接在其结尾,都是一个合法解。该解的长度一定是奇数。
对于原串的每个长度>2,且结尾两个字符相同的子串而言,将其除了最后两个字符之外反向接在其结尾,都是一个合法解。该解的长度一定是偶数。
于是在SAM上统计一下就可以了……非常容易,O(n)。
别忘了减去长度为1的子串,以及长度为2,且两个字符相等的子串数。
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; typedef long long ll; #define MAXL 100000 #define MAXC 26 int v[2*MAXL+10],__next[2*MAXL+10],first[2*MAXL+10],e; void AddEdge(int U,int V){ v[++e]=V; __next[e]=first[U]; first[U]=e; } char s[MAXL+10];//文本串 int len/*文本串长度*/; struct SAM{ int one_of_endpos[2*MAXL+10]; int n/*状态数0~n-1*/,maxlen[2*MAXL+10],minlen[2*MAXL+10],trans[2*MAXL+10][MAXC],slink[2*MAXL+10]; int new_state(int _maxlen,int _minlen,int _trans[],int _slink){ maxlen[n]=_maxlen; minlen[n]=_minlen; for(int i=0;i<MAXC;++i){ if(_trans==NULL){ trans[n][i]=-1; } else{ trans[n][i]=_trans[i]; } } slink[n]=_slink; return n++; } int add_char(char ch,int u,int pos){ if(u==-1){ return new_state(0,0,NULL,-1); } int c=ch-'a'; int z=new_state(maxlen[u]+1,-1,NULL,-1); one_of_endpos[z]=pos; int v=u; while(v!=-1 && trans[v][c]==-1){ trans[v][c]=z; v=slink[v]; } if(v==-1){//最简单的情况,suffix-path(u->S)上都没有对应字符ch的转移 minlen[z]=1; slink[z]=0; return z; } int x=trans[v][c]; if(maxlen[v]+1==maxlen[x]){//较简单的情况,不用拆分x minlen[z]=maxlen[x]+1; slink[z]=x; return z; } int y=new_state(maxlen[v]+1,-1,trans[x],slink[x]);//最复杂的情况,拆分x slink[y]=slink[x]; minlen[x]=maxlen[y]+1; slink[x]=y; minlen[z]=maxlen[y]+1; slink[z]=y; int w=v; while(w!=-1 && trans[w][c]==x){ trans[w][c]=y; w=slink[w]; } minlen[y]=maxlen[slink[y]]+1; return z; } void dfs(int U){ for(int i=first[U];i;i=__next[i]){ dfs(v[i]); one_of_endpos[U]=one_of_endpos[v[i]]; } } void work_slink_tree(){ for(int i=1;i<n;++i){ AddEdge(slink[i],i); } dfs(0); } }sam; ll ans; bool vis[1001]; int main(){ // freopen("uestc.h.in","r",stdin); scanf("%s",s); len=strlen(s); int U=sam.add_char(0,-1,0); for(int i=0;i<len;++i){ U=sam.add_char(s[i],U,i); } sam.work_slink_tree(); for(int i=0;i<len;++i){ if(!vis[s[i]]){ vis[s[i]]=1; --ans; } } memset(vis,0,sizeof(vis)); for(int i=0;i<len-1;++i){ if(s[i]==s[i+1] && (!vis[s[i]])){ vis[s[i]]=1; --ans; } } for(int i=1;i<sam.n;++i){ ans+=(ll)(sam.maxlen[i]-sam.minlen[i]+1); if(s[sam.one_of_endpos[i]]==s[sam.one_of_endpos[i]-1]){ ans+=(ll)(sam.maxlen[i]-max(2,sam.minlen[i])+1); } } printf("%lld ",ans); return 0; }