问题描述
给定一个字符串,求这个字符串中最长的回文串的长度。
处理原串
首先,回文串有两种形式。
1.长度为奇数。如 aaa,回文中心在字母上。
2.长度为偶数。如 aaaa,回文中心在两个字母之间。
考虑怎样同时计算这两种情况的答案。
在字母之间插入一些特殊字符,例如, aaa 变为 ~#a#a#a#,aaaa 变为 ~#a#a#a#a# .(前面的 ~ 可以使代码实现方便一些,防止找到外面去,现在没用)
这样操作后,可以发现,所有回文串都转化为了第一种情况。
~ 设为新字符串下标为 0 的位置上的字符。
设原字符串长为 (n) ,则新字符串最后一位的下标为 (n imes 2+1) .
算法流程
设 ( ext{mr}) 表示当前触及的最右边的字符的位置,( ext{mid}) 表示包含当前触及的最右边的字符的最大回文串的回文中心的位置。
设数组 (p) , (p_i) 表示以位置 (i) 为对称中心,在 (1 sim ext{mr}) 的范围内的最长的回文串的回文半径。
可以发现,要求的答案即为 (p_i) 的最大值。
从小到大枚举 (i) ,考虑如何更新 (p_i) .
因为 ( ext{mid}) 一定已经更新过,所以 (i) 一定在 ( ext{mid}) 的右边。
故分为两种情况讨论。
1.若 (i) 在 ( ext{mr}) 的左边。
设 (i) 关于 ( ext{mid}) 的对称点为 (j) .(显然, (j) 可以由中点公式直接得到)
下面,称包含当前触及的最右边的字符的最大回文串为大回文串。
(p_j) 对应的回文串为前面的小回文串,还未更新的 (p_i) 对应的回文串为后面的小回文串。
后面的小回文串的右端如果与大回文串的右端重合,为特殊情况,需要将 (p_j) (为什么是 (p_j) 见下面一种情况)与 ( ext{mr}-i+1) (在边界内可能延伸的最大的长度)取 (min) 得到 (p_i),同时,后面的小回文串还可能继续往外延伸。
在后面的小回文串的右端不与大回文串的右端重合的情况下,
考虑把 (p_i) 对应的回文串对称过去,必然能得到 (p_j) 对应的回文串,即 (p_i=p_j) 。
下面证明这一结论。
根据在大回文串中的对称性,
因为后面的小回文串的两端均还有在大回文串内的字符,
所以前面的小回文串的两端也还有在大回文串内的字符。
同样根据在大回文串中的对称性,
又因为前面的小回文串的两端还有在大回文串内的字符,
且前面的小回文串已经是一个极大回文串了,(根据 (p_j) 的定义)
所以前面的小回文串不能再延伸了,
所以后面的小回文串也不能再延伸了。
所以,在此情况下,两字符串全等。
整理为一般的做法,则 (p_i=min(p_{ ext{mid} imes 2-i}, ext{mr}-i+1)) .(其中, ( ext{mid} imes 2-i) 即为 (j))
前一种情况(后面的小回文串的右端与大回文串的右端重合)下需要的往外延伸,暴力即可。记得更新 ( ext{mr}) 和 ( ext{mid}) 。
2.若 (i) 在 ( ext{mr}) 的右边。
直接把 (p_i) 设为 0 或 1 ,同样暴力尝试向外延伸即可。
最后要注意,求得的答案为转化后的串中最大回文串的回文半径,经过多次观察,原问题的答案即为现在求得的答案减去 1 。
复杂度证明
首先,遍历 (i) 的复杂度为 (O(n)) .
在第一种情况的后一种情况(后面的小回文串的右端不与大回文串的右端重合)下,即使尝试,也不会向外延伸,直接得到了 (p_i) .
在其它情况下,得到初始的答案后,每暴力往外延伸一步, ( ext{mr}) 也往右移一步。串长为 (O(n)) 级别的,所以 ( ext{mr}) 移动的次数是 (O(n)) 级别的,所以暴力延伸时尝试成功的次数也只会是 (O(n)) 级别的。
所以,总复杂度为 (O(n)) 的。
代码实现
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
const int N=3e7+10;
char a[N],s[N];
int n,p[N];
int MIN(int x,int y)
{
return x<y?x:y;
}
int MAX(int x,int y)
{
return x>y?x:y;
}
int manacher()
{
int mr=0,mid=0;
for(int i=1;i<n;i++)
{
if(i<mr) p[i]=MIN(p[mid*2-i],mr-i+1);
while(a[i+p[i]]==a[i-p[i]]) p[i]++;
if(i+p[i]-1>mr) mr=i+p[i]-1,mid=i;
//i+p[i]-1为这个串的右端
}
int ans=0;
for(int i=1;i<=n;i++) ans=MAX(ans,p[i]);
return ans-1;
}
int main()
{
scanf("%s",s+1);
n=strlen(s+1);
a[0]='~',a[1]='#';
for(int i=1;i<=n;i++)
a[i*2]=s[i],a[i*2+1]='#';
n=n*2+1;
printf("%d",manacher());
return 0;
}