我们一句话题意:求两个字符串的公共回文子串的数量;
首先对于每个串构造一个回文自动机,然后由PAM的定义可知:对于PAM上从根节点转移方式相同所到达的点代表的回文串是相同的;
这样对于两个PAM同时dfs,每次dfs到的节点的数值(在其原串中的出现数量)相乘,然后累加到答案里;
注意:要从偶原点和奇原点各自跑一遍dfs;
#include <bits/stdc++.h> #define inc(i,a,b) for(register int i=a;i<=b;i++) #define dec(i,a,b) for(register int i=a;i>=b;i--) using namespace std; class node2{ public: char s[50010]; class node1{ public: int link,len; int ch[27]; }pam[50010]; long long f[50010]; int size=0; int root1=size++; int root2=size++; int last=root2; void set(){ pam[root1].link=root2,pam[root1].len=0; pam[root2].link=root2,pam[root2].len=-1; } void add(int to,int pos){ int u=last; while(s[pos-pam[u].len-1]!=s[pos]) u=pam[u].link; if(!pam[u].ch[to]){ int neww=size++; pam[neww].len=pam[u].len+2; int v=pam[u].link; while(s[pos-pam[v].len-1]!=s[pos]) v=pam[v].link; pam[neww].link=pam[v].ch[to]; pam[u].ch[to]=neww; } last=pam[u].ch[to]; f[last]++; } }PAM1,PAM2; long long ans=0; void dfs(int x,int y){ if(x+y>2) ans=(ans+PAM1.f[x]*PAM2.f[y]); inc(i,0,25){ if(PAM1.pam[x].ch[i]&&PAM2.pam[y].ch[i]){ dfs(PAM1.pam[x].ch[i],PAM2.pam[y].ch[i]); } } } int main() { scanf("%s %s",PAM1.s+1,PAM2.s+1); int n=strlen(PAM1.s+1),m=strlen(PAM2.s+1); PAM1.set(); PAM2.set(); inc(i,1,n) PAM1.add(PAM1.s[i]-'A',i); inc(i,1,m) PAM2.add(PAM2.s[i]-'A',i); dec(i,PAM1.size-1,1) PAM1.f[PAM1.pam[i].link]+=PAM1.f[i]; dec(i,PAM2.size-1,1) PAM2.f[PAM2.pam[i].link]+=PAM2.f[i]; dfs(1,1); dfs(0,0); cout<<ans; } /* PUPPY PUPPUP */