JSOI 2016 扭动的字符串
题面描述
给出两个长度为(n)的字符串(A,B)
(S(i,j,k))表示把(A)中的([i,j])和(B)中的([j,k])拼接起来的字符串
问所有回文的(S(i,j,k))或者(A,B)中的回文子串的最长长度
思路
枚举回文串的中心。
可以发现,如果能在当前字符串内扩展就尽量扩展,不能扩展了再尝试和另一个字符串匹配。
对于前者,使用(manacher)算法
对于后者,二分一个长度,用(hash)判断能否匹配。
以上
代码
#include<bits/stdc++.h>
#define pii pair<int,int>
#define mkp(x,y) make_pair(x,y)
using namespace std;
const int sz=2e5+7;
const int p1=998244353;
const int p2=1e8+7;
const int q1=5271314;
const int q2=2374899;
int n,m;
int ans;
int p[2][sz];
char s[2][sz];
char ss[2][sz<<1];
int w[2][sz];
int pre[2][2][sz],suf[2][2][sz];//字符串,哈希,位置
void init(int tp){
int mid=0,r=0;
for(int i=1;i<=m;i++){
if(i<r) p[tp][i]=min(r-i,p[tp][2*mid-i]);
else p[tp][i]=1;
while(ss[tp][i+p[tp][i]]==ss[tp][i-p[tp][i]]) p[tp][i]++;
if(i+p[tp][i]>r) mid=i,r=i+p[tp][i];
}
}
void hash(int tp){
int hs0,hs1;
hs0=hs1=0;
for(int i=1;i<=n;i++){
hs0=(1ll*hs0*q1%p1+s[tp][i])%p1;
hs1=(1ll*hs1*q2%p2+s[tp][i])%p2;
pre[tp][0][i]=hs0;
pre[tp][1][i]=hs1;
}
hs0=hs1=0;
for(int i=n;i>=1;i--){
hs0=(1ll*hs0*q1%p1+s[tp][i])%p1;
hs1=(1ll*hs1*q2%p2+s[tp][i])%p2;
suf[tp][0][i]=hs0;
suf[tp][1][i]=hs1;
}
}
pii gs(int tp,int st,int l){
int sum1=pre[tp][0][st]-1ll*w[0][l]*pre[tp][0][st-l]%p1;
int sum2=pre[tp][1][st]-1ll*w[1][l]*pre[tp][1][st-l]%p2;
if(sum1<0) sum1+=p1;
if(sum2<0) sum2+=p2;
return mkp(sum1,sum2);
}
pii gn(int tp,int st,int l){
int sum1=suf[tp][0][st]-1ll*w[0][l]*suf[tp][0][st+l]%p1;
int sum2=suf[tp][1][st]-1ll*w[1][l]*suf[tp][1][st+l]%p2;
if(sum1<0) sum1+=p1;
if(sum2<0) sum2+=p2;
return mkp(sum1,sum2);
}
int main(){
scanf("%d",&n);
scanf("%s",s[0]+1);
scanf("%s",s[1]+1);
ss[0][0]=ss[1][0]='#';
ss[0][1]=ss[1][1]='|';
for(int i=1;i<=n;i++){
ss[0][2*i]=s[0][i];
ss[1][2*i]=s[1][i];
ss[0][2*i+1]=ss[1][2*i+1]='|';
}
m=2*n+1;
w[0][0]=w[1][0]=1;
for(int i=1;i<=n;i++){
w[0][i]=1ll*w[0][i-1]*q1%p1;
w[1][i]=1ll*w[1][i-1]*q2%p2;
}
init(0),init(1);
hash(0),hash(1);
int L,R,sl,sr,len,l,r,mid;
for(int i=2;i<m;i++){
L=i-p[0][i],R=i+p[0][i];
sl=L/2,sr=R/2-1;
len=p[0][i]-1;
l=0,r=min(sl,n-sr+1);
while(l<r){
mid=(l+r+1)>>1;
if(gs(0,sl,mid)==gn(1,sr,mid)) l=mid;
else r=mid-1;
}
ans=max(ans,len+2*l);
}
for(int i=2;i<m;i++){
L=i-p[1][i],R=i+p[1][i];
sl=L/2+1,sr=R/2;
len=p[1][i]-1;
l=0,r=min(sl,n-sr+1);
while(l<r){
mid=(l+r+1)>>1;
if(gs(0,sl,mid)==gn(1,sr,mid)) l=mid;
else r=mid-1;
}
ans=max(ans,len+2*l);
}
printf("%d
",ans);
}