题目大意:给你一个字符串,让你求出有多少对相交的回文子串
啊啊啊啊降智了,我怎么又忘了正难则反!
求相交会很难搞。把问题转化成求互不相交的回文子串再减一下就行了
先利用$PAM$求出以每个位置为末尾的回文子串数量,这个数量就是此时构造末尾节点在$fail$树中的深度
再把串翻过来,用同样的方法求出每个位置为开头的回文子串数量
对其中一个数组求前缀和,用乘法原理算一下就行了
空间开不下怎么办?以时间换空间,用邻接表存儿子!每次跳儿子都暴力遍历一次邻接表
1 #include <cstdio> 2 #include <cstring> 3 #include <algorithm> 4 #define ll long long 5 #define N1 2000010 6 using namespace std; 7 const int p=51123987; 8 9 template <typename _T> void read(_T &ret) 10 { 11 ret=0; _T fh=1; char c=getchar(); 12 while(c<'0'&&c>'9'){ if(c=='-') fh=-1; c=getchar(); } 13 while(c>='0'&&c<='9'){ ret=ret*10+c-'0'; c=getchar(); } 14 ret=ret*fh; 15 } 16 /*void exgcd(ll a,ll b,ll &x,ll &y) 17 { 18 if(!b){ x=1; y=0; return; } 19 exgcd(b,a%b,x,y); ll t=x; x=y; y=t-a/b*y; 20 }*/ 21 22 int idx(char c){ return c-'a'; } 23 24 struct Edge{ 25 int to[N1],nxt[N1],val[N1],head[N1],cte; 26 void ae(int u,int v,int w) 27 { cte++; to[cte]=v; nxt[cte]=head[u]; val[cte]=w; head[u]=cte; } 28 }e; 29 30 namespace PAM{ 31 int pre[N1],dep[N1],sum[N1],sz[N1],la,tot; 32 void clr() 33 { 34 memset(pre,0,sizeof(pre)); memset(dep,0,sizeof(dep)); 35 memset(sum,0,sizeof(sum)); memset(&e,0,sizeof(e)); 36 } 37 void init(){ la=tot=1; pre[0]=pre[1]=1; dep[1]=-1; } 38 int same(char *str,int p,int i){ return str[i-dep[p]-1]==str[i]; } 39 int trs(int x,int c) 40 { 41 for(int j=e.head[x];j;j=e.nxt[j]) 42 if(e.val[j]==c) return e.to[j]; 43 return 0; 44 } 45 int insert(char *str,int i) 46 { 47 int p=la,np,fp,tp,c=idx(str[i]); 48 while(!same(str,p,i)) p=pre[p]; 49 if(!(tp=trs(p,c))) //!trs[p][c] 50 { 51 np=++tot; 52 dep[np]=dep[p]+2; 53 fp=pre[p]; 54 while(!same(str,fp,i)) fp=pre[fp]; 55 pre[np]=trs(fp,c); //trs[fp][c] 56 e.ae(p,np,c); //trs[p][c]=np; 57 p=np; 58 }else p=tp; 59 la=p; 60 sum[p]=sum[pre[p]]+1; 61 return p; 62 } 63 }; 64 65 int n; 66 char str[N1]; 67 ll lsum[N1],rsum[N1]; 68 69 int main() 70 { 71 int i,j,x; ll ans=0; 72 scanf("%d",&n); 73 scanf("%s",str+1); 74 PAM::init(); 75 for(i=1;i<=n;i++) x=PAM::insert(str,i), lsum[i]=PAM::sum[x];// ans+=lsum[i]; 76 PAM::clr(); PAM::init(); 77 reverse(str+1,str+n+1); 78 for(i=1;i<=n;i++) x=PAM::insert(str,i), rsum[i]=PAM::sum[x], rsum[i]+=rsum[i-1]; //rsum[i]+=rsum[i-1]; 79 if(rsum[n]&1) ans=1ll*((rsum[n]-1)/2%p)*(rsum[n]%p)%p; 80 else ans=1ll*(rsum[n]/2%p)*((rsum[n]-1)%p)%p; 81 for(i=2;i<=n;i++) ans=(ans-1ll*lsum[i-1]*rsum[n-i+1]%p+p)%p; 82 printf("%I64d ",(ans%p+p)%p); 83 return 0; 84 }