题目链接:http://acm.csust.edu.cn/problem/2002
CSDN食用链接:https://blog.csdn.net/qq_43906000/article/details/107629093
Description
ppq最喜欢次banana了,可这和本题bing没有什么关系。
对于给定的一个字符串,求有多少对互不重叠的非空回文子串。
设给定一个字符串长度为(len),若两个回文串([l_1,r_1])和([l_2,r_2])满足(0leq l_1leq r_1<l_2leq r_2<len),则它们为一对互不重叠的非空回文子串。
Input
输入仅一行,一个仅由小写英文字母构成的字符串。
字符串长度为(len(1 le len le 1e5))
Output
输出一行,代表答案,请对 ((1e9+7)) 取模。
Sample Input 1
banana
Sample Output 1
25
这道题挺妙的。。。要处理互不重叠的回文子串,一下子可能反应不过来,实际上我们可以考虑用前缀回文和后缀回文进行处理。
比如对于上面的(banana)而言,我们先用自动机处理出它的前缀回文:
for (int i=0; i<len; i++) {
pam.add(s[i]);
if (i==0) {
used[i]=pam.num[pam.last];
continue;
}
used[i]=(1LL*used[i-1]+pam.num[pam.last])%mod;
}
那么我们在将字符串颠倒过来,那么对于后缀回文(a),而言,它所能匹配的回文串肯定是在(banan)里面,那么也就是:
pam.init();
for (int i=len-1; i>=1; i--) {
pam.add(s[i]);
ans=(ans+1LL*used[i-1]*pam.num[pam.last]%mod)%mod;
}
于是此题就愉快地结束了。
以下是AC代码:
#include <bits/stdc++.h>
using namespace std;
const int mac=1e5+10;
const int mod=1e9+7;
char s[mac];
int used[mac];
struct PAM
{
int next[mac][30];//指向串尾当前串两端加上同一个字符构成
int fail[mac];//fail指针,失配后跳转的fail指针指向的节点
int cnt[mac];//表示节点i的本质不同的串的个数(不全,最后count跑一遍才完整)
int num[mac];//以节点i表示的最长回文串的最右端为回文结尾的回文串的个数
int len[mac];//len[i]表示节点i的回文串长度
int S[mac];//存放读入的字符
int last;//指向新添加的一个字符后形成的最长回文串的节点
int n;//表示添加的字符的个数
int p;//表示添加的节点的个数
int newnode(int l){//新建节点
for (int i=0; i<30; i++) next[p][i]=0;
cnt[p]=num[p]=0;
len[p]=l;
return p++;
}
void init(){
p=0; newnode(0); newnode(-1);
last=0; n=0;
S[n]=-1; fail[0]=1;
}
int get_fail(int x){
while (S[n-len[x]-1]!=S[n]) x=fail[x];
return x;
}
void add(int c){
c-='a';
S[++n]=c;
int cur=get_fail(last);//通过上一个回文串找到这个回文串的匹配位置
if (!next[cur][c]){//如果这个串没有出现过,说明了出现了一个新的本质不同的回文串
int now=newnode(len[cur]+2);
fail[now]=next[get_fail(fail[cur])][c];
next[cur][c]=now;
num[now]=num[fail[now]]+1;
}
last=next[cur][c];
cnt[last]++;
}
void count(){
for (int i=p-1; i>=0; i--) cnt[fail[i]]+=cnt[i];
}
}pam;
int main(int argc, char const *argv[])
{
scanf ("%s",s);
pam.init();
int len=strlen(s);
for (int i=0; i<len; i++){
pam.add(s[i]);
if (i==0) {used[i]=pam.num[pam.last]; continue;}
used[i]=(1LL*used[i-1]+pam.num[pam.last])%mod;
}
long long ans=0;
pam.init();
for (int i=len-1; i>=1; i--){
pam.add(s[i]);
ans=(ans+1LL*used[i-1]*pam.num[pam.last]%mod)%mod;
}
printf("%lld
",ans);
return 0;
}