题意:
在一个长度为n的只含a,b的字符串中选取一个子序列,使得:
- 位置和字符都关于某条对称轴对称。
- 不能是连续的一段。
求方案数对$10^{9}+7$取模的值。
$nleq 10^5$。
题解:
首先答案可以用回文子序列个数减回文子串个数得到,回文子串可以Hash+二分求出。
考虑怎么求回文子序列,先枚举对称轴x,统计出满足$str_{x-i}=str{x+i}$的i的个数cnt,那么答案就是$2^{cnt}-1$。
注意到位置$i,j$会给对称轴$frac{i+j}{2}$产生贡献,那么就有一个套路的做法:(以a为例,b同理)
构造多项式$f_i =sum limits_{i=1}^{n}{[str_i = a]x^{i}}$,再令$F=f*f$,那么$lceil frac{F_i}{2} ceil$就是对称轴为$frac{i}{2}$时满足要求的a的对数。
注意到每项的系数不会超过ntt模数,于是ntt做一下就行了,复杂度$O(nlog{n})$。
套路:
- 位置$i,j$对位置$i+j$产生贡献$ ightarrow$构造多项式求解。
代码:
#include<bits/stdc++.h> #define maxn 600005 #define maxm 500005 #define inf 0x7fffffff #define ll long long #define Mod 1000000007 #define mod 998244353 #define g 3 #define rint register ll #define debug(x) cerr<<#x<<": "<<x<<endl #define fgx cerr<<"--------------"<<endl #define dgx cerr<<"=============="<<endl using namespace std; char str[maxn]; ll ind[maxn],pre[maxn],suf[maxn]; struct poly{ ll a[maxn],n; inline void clear(){memset(a,0,sizeof(a)),n=0;} inline ll pw(ll a,ll b){ll r=1;while(b)r=(b&1)?r*a%mod:r,a=a*a%mod,b>>=1;return r;} inline void ntt(ll op){ for(ll i=0;i<n;i++) if(i>ind[i]) swap(a[i],a[ind[i]]); for(ll l=1;l<=n;l<<=1){ ll p=pw(g,(mod-1)/l); if(op==-1) p=pw(p,mod-2); for(ll i=0;i<n;i+=l) for(ll j=i,w=1;j<i+(l>>1);j++,w=w*p%mod){ ll x=a[j],y=w*a[j+(l>>1)]%mod; a[j]=(x+y)%mod,a[j+(l>>1)]=(x-y+mod)%mod; } } if(op==-1){ ll inv=pw(n,mod-2); for(ll i=0;i<n;i++) a[i]=a[i]*inv%mod; } } inline poly operator*(const poly b)const{ poly res; res.clear(),res.n=max(n,b.n); for(ll i=0;i<res.n;i++) res.a[i]=a[i]*b.a[i]%mod; return res; } }; inline ll read(){ ll x=0,f=1; char c=getchar(); for(;!isdigit(c);c=getchar()) if(c=='-') f=-1; for(;isdigit(c);c=getchar()) x=x*10+c-'0'; return x*f; } inline ll pw(ll a,ll b,ll M){ll r=1;while(b)r=(b&1)?r*a%M:r,a=a*a%M,b>>=1;return r;} inline ll calc(ll n){ for(ll i=1;i<=n;i++) pre[i]=(pre[i-1]*2%mod+(str[i]=='b'))%mod; for(ll i=n;i>=1;i--) suf[i]=(suf[i+1]*2%mod+(str[i]=='b'))%mod; ll ans=0; for(ll i=1;i<=n*2;i++){ ll px=i/2,py=i-px,l=1,r=min(px,n-py+1),res=0; while(l<=r){ ll mid=l+r>>1,w=pw(2,mid,mod); ll t1=(suf[px-mid+1]-suf[px+1]*w%mod+mod)%mod; ll t2=(pre[py+mid-1]-pre[py-1]*w%mod+mod)%mod; if(t1==t2) res=mid,l=mid+1; else r=mid-1; } ans=(ans+res)%Mod; } return ans; } int main(){ scanf("%s",str+1); ll n=strlen(str+1),ans=0; poly A,B; A.clear(),B.clear(); for(ll i=1;i<=n;i++){ if(str[i]=='a') A.a[i]=1,B.a[i]=0; else B.a[i]=1,A.a[i]=0; } ll m=1; while(m<=n*2) m<<=1; for(ll i=0;i<m;i++) ind[i]=(i&1)?((ind[i>>1]>>1)|(m>>1)):(ind[i>>1]>>1); A.n=B.n=m; A.ntt(1),B.ntt(1); poly C=(A*A),D=(B*B); C.ntt(-1),D.ntt(-1); for(ll i=0;i<m;i++){ ll t1=C.a[i]-C.a[i]/2,t2=D.a[i]-D.a[i]/2; ans=(ans+pw(2,t1+t2,Mod)+Mod-1)%Mod; } printf("%lld ",(ans-calc(n)+Mod)%Mod); return 0; }