【BZOJ3160】万径人踪灭
Description
Input
Output
Sample Input
Sample Output
HINT
题解:自己想出来1A,先撒花~(其实FFT部分挺裸的)
做这道题,第一思路很重要,显然看到这题的第一想法就是ans=总数-不合法(不要问我为什么显然)。因为向这种用补集法的题一般都会给一些很奇葩的限制条件,但是一旦换个角度去想就很水了,好了不多说废话了。
显然,不合法的情况,也就是连续的回文区间的方案数,我们直接上Manacher就搞定了嘛!答案就是所有对称轴的(最长回文串长度+1)/2之和(是的,很显然)
对于不合法的情况,我们发现两串对称的情况跟卷积的形式类似(FFT做多了吧?),但是问题来了,怎么构造出一个卷积,使得它的值就是回文子串的个数呢?
我们发现原串只有a或b,所以思考能不能构造出一种卷积,使得对应位置的值跟下面一样
a*b->0 b*a->0 a*a->1 b*b->1
如果把a看成0,b看成1,这显然是一个异或,然而并没什么卵用。但我们如果把a看成0,b看成1,可以满足只有b*b是1,其他都是0;同理,把a看成1,b看成0,可以满足只有a*a是1,其他都是0,然后我们对这两种情况分别求一次卷积,就能得到:以i为对称中心的最长子序列的回文半径长度。这里注意一下,用于是两个一样的多项式相乘,所以每对字符会被算成两次(单个字符自我对称的除外),所以我们要的回文半径应该是(x+1)/2
然而半径长度并不是方案数,由于每对对称的字符都可以选或不选,所以对答案的贡献就是2^长度-1(因为你不能一个也不选吧?)
好吧,感觉说了这么多又有点说不明白了,所以欢迎提问和hack~
#include <cstdio> #include <cstring> #include <iostream> #include <cmath> #define pi acos(-1.0) #define mod 1000000007 using namespace std; int n,ans; struct cp { double x,y; cp (double a,double b){x=a,y=b;} cp (){} cp operator + (cp a){return cp(x+a.x,y+a.y);} cp operator - (cp a){return cp(x-a.x,y-a.y);} cp operator * (cp a){return cp(x*a.x-y*a.y,x*a.y+y*a.x);} }n1[1<<20],n2[1<<20]; int s[1<<20],ret[1<<20],rl[1<<20],pn[1<<20]; char str[1<<20]; void FFT(cp *a,int len,int f) { int i,j,k,h; cp t; for(i=k=0;i<len;i++) { if(i>k) swap(a[i],a[k]); for(j=(len>>1);(k^=j)<j;j>>=1); } for(h=2;h<=len;h<<=1) { cp wn(cos(f*2*pi/h),sin(f*2*pi/h)); for(j=0;j<len;j+=h) { cp w(1.0,0); for(k=j;k<j+h/2;k++) t=w*a[k+h/2],a[k+h/2]=a[k]-t,a[k]=a[k]+t,w=w*wn; } } } void work(cp *a,cp *b,int len) { FFT(a,len,1),FFT(b,len,1); for(int i=0;i<len;i++) a[i]=a[i]*b[i]; FFT(a,len,-1); for(int i=0;i<len;i++) ret[i]+=(int)(a[i].x/len+0.1); } int main() { scanf("%s",str); int i,mx,pos,len=strlen(str); for(i=0;i<len;i++) s[n++]=0,s[n++]=str[i]-'a'; s[n++]=0; for(mx=-1,i=0;i<n;i++) { if(mx>i) rl[i]=min(mx-i+1,rl[2*pos-i]); else rl[i]=1; for(;i+rl[i]<n&&rl[i]<=i&&s[i+rl[i]]==s[i-rl[i]];rl[i]++); if(mx<i+rl[i]-1) mx=i+rl[i]-1,pos=i; } for(i=0;i<n;i++) ans=(ans-rl[i]/2+mod)%mod; for(len=1;len<2*n;len<<=1); for(i=0;i<len;i++) n1[i]=n2[i]=cp(0,0); for(i=1;i<n;i+=2) n1[i]=n2[i]=cp(s[i],0); work(n1,n2,len); for(i=0;i<len;i++) n1[i]=n2[i]=cp(0,0); for(i=1;i<n;i+=2) n1[i]=n2[i]=cp(s[i]^1,0); work(n1,n2,len); for(pn[0]=i=1;i<=n;i++) pn[i]=(pn[i-1]<<1)%mod; for(i=0;i<n;i++) ans=(ans+pn[(ret[i<<1]+(i&1))>>1]-1)%mod; printf("%d",ans); return 0; }