暑假的时候状态不好,现在发现这东西贼水
顺便复兴一下 $FFT$ 板子
Description
给定一个只由 $a,b$ 组成的字符串 $s$,求"位置和字符都关于某条对称轴对称"、且不是连续子串的回文子序列个数,答案对 $10^9+7$ 取模
$|s|\le 10^5$
Solution
按照题面概括,答案可以写成:回文子序列个数 $-$ 回文子串个数
后面的部分随便做一下就行了,比如 $hash+$ 二分或 $Manacher$
然后考虑前面的部分:
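令 $f_i=\sum\limits_{j\le i}[s_j=s_{2i-j}]$($i$ 可以是整数或半整数;当 $i$ 为整数时 $j=i$ 这一项对应中间那个字符)
那么以 $i$ 为对称中心的答案就是 $2^{f_i}-1$
设 $a=1,b=0$,那么 $1\times 1=1$,$0\times 0=0\times 1=1\times 0=0$,把这个序列和自己卷积即可
再令 $a=0,b=1$ 做一遍一样的。两次结果相加得到 $g_k=\sum\limits_{j+l=k}[s_j=s_l]$(按有序对计数),于是 $f_{k/2}=\left\lfloor\frac{g_k+1}{2}\right\rfloor$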
Code
#include<bits/stdc++.h>
using namespace std;
// Competitive-programming shorthand: from here on the preprocessor rewrites
// every `int` token to `long long`, so the program entry point has to be
// declared `signed main` (a `long long main` would be ill-formed).
#define int long long
// `reg` expands to the `register` storage-class specifier.
// NOTE(review): `register` was deprecated in C++11 and removed in C++17;
// code compiled as C++17 or later should avoid `reg`.
#define reg register
namespace yspm {
// Reads one signed decimal integer from stdin.
// Every character before the first digit is skipped; a '-' among them makes
// the result negative. Stops at EOF (returning 0 if no digit was read) rather
// than spinning forever. The character that ends the number is consumed.
inline long long read()
{
    long long res = 0;
    int sign = 1;
    int c = std::getchar();  // int, not char: it must be able to hold EOF
    while (c != EOF && !std::isdigit(c)) {
        if (c == '-') sign = -1;
        c = std::getchar();
    }
    while (c != EOF && std::isdigit(c)) {
        res = res * 10 + (c - '0');
        c = std::getchar();
    }
    return res * sign;
}

// Capacity of every global work array (FFT buffer, bit-reversal table,
// Manacher buffers).
const int N = 2e6 + 10;
// ans[k]: filled by solve() with the number of ordered index pairs (j, l),
//         j + l == k, whose characters are equal.
// r[]:    bit-reversal permutation consumed by fft().
// n, m:   not referenced inside this namespace.
int ans[N], r[N], n, m;
const double pi = std::acos(-1.0);

// Minimal complex number used by the FFT.
struct node {
    double x, y;
    node() {}
    node(double xx, double yy) : x(xx), y(yy) {}
    node operator+(node o) const { return node(x + o.x, y + o.y); }
    node operator-(node o) const { return node(x - o.x, y - o.y); }
    node operator*(node o) const { return node(x * o.x - y * o.y, x * o.y + y * o.x); }
} A[N];

// Arithmetic modulo 1e9+7 on residues in [0, mod). Explicit `long long`
// keeps the product of two residues from overflowing even when `int` is
// 32 bits.
const long long mod = 1e9 + 7;
inline long long add(long long x, long long y) { return x + y >= mod ? x + y - mod : x + y; }
inline long long del(long long x, long long y) { return x - y < 0 ? x - y + mod : x - y; }
inline long long mul(long long x, long long y) { return x * y % mod; }

// x^y modulo `mod` by binary exponentiation (y >= 0).
inline long long ksm(long long x, long long y)
{
    long long res = 1;
    for (; y; y >>= 1, x = mul(x, x))
        if (y & 1) res = mul(res, x);
    return res;
}

// In-place iterative radix-2 FFT of f[0..n-1].
// n must be a power of two and r[] must already hold the bit-reversal
// permutation for n. opt = 1 is the forward transform, opt = -1 the inverse;
// the inverse rescales only the real parts by 1/n.
inline void fft(node *f, int n, int opt)
{
    for (int i = 0; i < n; ++i)
        if (i < r[i]) std::swap(f[i], f[r[i]]);
    for (int p = 2; p <= n; p <<= 1) {
        const int half = p >> 1;
        // Principal p-th root of unity (conjugated for the inverse transform).
        const node w(std::cos(pi / half), opt * std::sin(pi / half));
        for (int k = 0; k < n; k += p) {
            node buf(1, 0);
            for (int l = k; l < k + half; ++l, buf = buf * w) {
                const node t = buf * f[l + half];
                f[l + half] = f[l] - t;
                f[l] = f[l] + t;
            }
        }
    }
    if (opt == -1)
        for (int i = 0; i < n; ++i) f[i].x /= n;
}

#define ull unsigned long long

// Manacher's algorithm on the input interleaved with '#' separators.
struct str {
    // Transformed text: '$', '#', t1, '#', t2, ..., tn, '#', '\0'.
    char s[N];
    // n: input length (set by the caller before calc); len[i]: palindrome
    // radius around s[i], counting s[i] itself; cnt is unused.
    int n, cnt, rmax, mid, len[N];

    // Returns the number of palindromic substrings of t[1..n], modulo `mod`.
    // Uses only its own buffers and leaves `n` unchanged, so it can be called
    // repeatedly.
    inline int calc(char *t)
    {
        const int tot = (n + 1) << 1;  // transformed length, excluding s[tot]
        s[0] = '$';                    // left sentinel, differs from everything
        s[1] = '#';
        for (int i = 1; i <= n; ++i) s[i << 1] = t[i], s[i << 1 | 1] = '#';
        s[tot] = '\0';                 // right sentinel, set on every call
        long long res = 0;
        rmax = mid = len[1] = 1;
        for (int i = 2; i < tot; ++i) {
            len[i] = i < rmax ? std::min(rmax - i, len[(mid << 1) - i]) : 1;
            while (s[i - len[i]] == s[i + len[i]]) ++len[i];
            if (i + len[i] > rmax) mid = i, rmax = i + len[i];
            // A radius of len[i] covers len[i] / 2 palindromes of the input.
            res = add(res, len[i] >> 1);
        }
        return res;
    }
} T;

char s[N];

// Returns (number of non-empty subsequences of t[1..len] whose positions and
// characters are symmetric about some centre) minus (number of palindromic
// substrings of t[1..len]), modulo `mod`. Only 'a' and 'b' are counted as
// letters. Safe to call more than once.
inline long long solve(char *t, int len)
{
    // Smallest power of two strictly above 2 * (len + 1): holds every index
    // sum j + l with 1 <= j, l <= len.
    int m = 1;
    while (m <= ((len + 1) << 1)) m <<= 1;
    r[0] = 0;
    for (int i = 1; i < m; ++i) r[i] = r[i >> 1] >> 1 | ((i & 1) ? (m >> 1) : 0);

    std::fill(ans, ans + m, 0);
    const char letters[2] = {'a', 'b'};
    for (int pass = 0; pass < 2; ++pass) {
        // Square the indicator sequence of one letter: coefficient k counts
        // ordered pairs (j, l), j + l == k, where both characters are it.
        std::fill(A, A + m, node(0, 0));
        for (int i = 1; i <= len; ++i)
            if (t[i] == letters[pass]) A[i].x = 1;
        fft(A, m, 1);
        for (int i = 0; i < m; ++i) A[i] = A[i] * A[i];
        fft(A, m, -1);
        for (int i = 0; i < m; ++i) ans[i] += static_cast<int>(A[i].x + 0.5);
    }

    // Centre sum k offers f = (ans[k] + 1) / 2 independent choices (each
    // matching pair, plus the middle character when k is even), giving
    // 2^f - 1 non-empty symmetric subsequences.
    long long sum = 0;
    for (int i = 0; i < m; ++i) sum = add(sum, del(ksm(2, (ans[i] + 1) >> 1), 1));
    T.n = len;
    return del(sum, T.calc(t));
}

// Reads the string into s[1..] and prints the answer.
// NOTE(review): buffers are sized for |s| <= 1e5 as in the statement.
signed main()
{
    if (std::scanf("%s", s + 1) != 1) s[1] = '\0';
    const int len = static_cast<int>(std::strlen(s + 1));
    std::printf("%lld\n", solve(s, len));
    return 0;
}
}
// Program entry point: hands control to the solver's main in namespace yspm
// and propagates its exit status.
signed main()
{
    const auto status = yspm::main();
    return status;
}