本文版权归ljh2000和博客园共有,欢迎转载,但须保留此声明,并给出原文链接,谢谢合作。
本文作者:ljh2000
作者博客:http://www.cnblogs.com/ljh2000-jump/
转载请注明出处,侵权必究,保留最终解释权!
题目链接:BZOJ3160
正解:FFT+manacher
解题报告:
参考博客:戳这里
题目求的是一个字符串的不连续回文子序列个数。
考虑用所有的回文子序列个数$-$连续回文子序列就是答案。
求连续回文子序列的个数只需要跑一遍$manacher$,然后得到以每个点为对称中心的$p$数组之后,可以直接统计出答案。
回文子序列的个数似乎不好考虑,我们不妨考虑以每个地方(包括间隔)为对称点的回文子序列个数。
我们如果知道了两边对应位置相等的个数有$x$个,根据二项式定理$C(n,1)+C(n,2)+C(n,3)+…+C(n,n)=2^n-1$,所以答案就是$2^x-1$。
而$a$、$b$是彼此独立的,所以我们可以分别考虑$a$和$b$。
我们设出一个多项式,若这一位是$a$那么系数就是$1$,容易发现把这个多项式平方之后,$i$项对应的系数就是以$i$为对称中心的相等的$a$的个数。
因为我一直写的是递归版的$FFT$,然后被卡常了...
拖了一个非递归版的$FFT$就愉快地$AC$了。
//It is made by ljh2000 #include <iostream> #include <cstdlib> #include <cstring> #include <cstdio> #include <cmath> #include <algorithm> #include <ctime> #include <vector> #include <queue> #include <map> #include <set> #include <string> #include <complex> using namespace std; typedef long long LL; typedef complex<double> C; const int MOD = 1000000007; const double pi = acos(-1); const int MAXN = 300011; int n,L,f[MAXN],mx,pos,m,p[MAXN]; char ch[MAXN],s[MAXN]; C a[MAXN],b[MAXN],aa[MAXN],bb[MAXN]; int ans[MAXN],tot,out,er[MAXN],R[MAXN]; //ans[i]表示以i为对称中心的两边的对称字符数量(包含i) inline int getint(){ int w=0,q=0; char c=getchar(); while((c<'0'||c>'9') && c!='-') c=getchar(); if(c=='-') q=1,c=getchar(); while (c>='0'&&c<='9') w=w*10+c-'0',c=getchar(); return q?-w:w; } inline LL fast_pow(LL x,LL y){ LL r=1; while(y>0) { if(y&1) r*=x,r%=MOD; x*=x; x%=MOD; y>>=1; } return r; } inline void fft(C *a,int n,int f){ for(int i=0;i<n;i++) if(i<R[i]) swap(a[i],a[R[i]]);//交换位置 for(int i=1;i<n;i<<=1){//待合并区间长度 C wn(cos(pi/i),sin(f*pi/i)),x,y;//这里就不用再*2了,因为合并后的区间长度是i的两倍 for(int j=0;j<n;j+=i<<1){//起始位置 C w(1,0); for(int k=0;k<i;k++,w*=wn){//第k个 x=a[j+k];y=w*a[j+i+k]; a[j+k]=x+y; a[j+i+k]=x-y; } } } } inline LL manacher(){ pos=0; mx=0; s[0]='%'; s[1]='#'; m=1; for(int i=0;i<n;i++) s[++m]=ch[i],s[++m]='#'; for(int i=1;i<=m;i++) { if(i<mx) p[i]=min(p[2*pos-i],mx-i); else p[i]=1; for(;i+p[i]<=m/*!!!*/ && s[i+p[i]]==s[i-p[i]];p[i]++); if(i+p[i]>mx) { mx=i+p[i]; pos=i; } tot+=p[i]/2; tot%=MOD;//一个回文串的贡献 } return tot; } inline void work(){ scanf("%s",ch); n=strlen(ch); int N=n<<1,ll=0; for(int i=0;i<=N;i++) er[i]=fast_pow(2,i); for(int i=0;i<n;i++) if(ch[i]=='a') a[i]=b[i]=1; for(L=1;L<=N;L<<=1) ll++; for(int i=0;i<L;i++) R[i]=(R[i>>1]>>1)|((i&1)<<(ll-1)); fft(a,L,1); fft(b,L,1); for(int i=0;i<L;i++) a[i]*=b[i]; fft(a,L,-1); for(int i=0;i<N;i++) ans[i]=(int)(a[i].real()/L+0.5); for(int i=0;i<n;i++) if(ch[i]=='b') aa[i]=bb[i]=1; fft(aa,L,1); fft(bb,L,1); for(int i=0;i<L;i++) aa[i]*=bb[i]; fft(aa,L,-1); for(int i=0;i<N;i++) ans[i]+=(int)(aa[i].real()/L+0.5); for(int i=0;i<N;i++) ans[i]=er[(ans[i]+1)/2]-1; for(int i=0;i<N;i++) out+=ans[i],out%=MOD; out-=manacher(); out+=MOD; out%=MOD; printf("%d",out); } int main() { work(); return 0; }