题意
定义串 (A_0= exttt{a}),(A_i=A_{i-1}+overline{A_{i-1}})(其中上划线代表 a b 互换),考虑串 (A_infty[0..n-1])。给定 (k-1) 次多项式 (f),求 (sum_{i=0}^{n-1}[A_infty[i]= exttt{b}]f(i))。(nleq 2^{500000}),(kleq 500)。
题解
我们首先考虑 (O(k^2log n)) 的做法:我们维护多项式 (F_{i, exttt{a}/ exttt{b}}(x)) 代表把字符串 (A_i) 平移到 (x) 处后,( exttt{a}/ exttt{b}) 的贡献。
我们有 (F_{i, exttt{a}}(x)=F_{i-1, exttt{a}}(x)+F_{i-1, exttt{b}}(x+2^{i-1}))((F_{i, exttt{b}}) 同理),处理系数平移需要 (O(k^2))(二项式定理展开),一共要递推到 (F_{log n, exttt{a}/ exttt{b}})。
如果你在调试时输出了多项式,你可能会发现在 (igeq k) 时,(F_{i, exttt{a}}=F_{i, exttt{b}})。证明可以参见官方题解。(考试时我输出了调试信息,但是没发现结论)
知道这个结论后,我们可以只处理到 (F_{k-1, exttt{a}/ exttt{b}}),剩下的部分借助自然数幂和计算。复杂度 (O(k^3+log n))。
#include<bits/stdc++.h>
using namespace std;
int getint(){
int ans=0;
char c=getchar();
while(!isdigit(c))c=getchar();
while(isdigit(c)){
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}
const int K=503,N=500010,mod=1e9+7;
int f[2][K],g[2][K];
int c[K][K];
char s[N];int n;
int p2[N];
int p[N],q[N];
int sp[N];
int qpow(int x,int y){
int ans=1;
while(y){
if(y&1)ans=ans*1ll*x%mod;
x=x*1ll*x%mod;
y>>=1;
}
return ans;
}
int main(){
freopen("angry.in","r",stdin);
freopen("angry.out","w",stdout);
scanf("%s",s+1);
n=strlen(s+1);
int k=getint();
if(n<k){
for(int i=n;i;i--)s[k-n+i]=s[i];
for(int i=1;i<=k-n;i++)s[i]='0';
n=k;
}
for(int i=0;i<k;i++)f[0][i]=getint();
for(int i=c[0][0]=1;i<=k;i++)
for(int j=c[i][0]=1;j<=k;j++)
c[i][j]=(c[i-1][j]+c[i-1][j-1])%mod;
p2[0]=1;for(int i=1;i<=n;i++)p2[i]=p2[i-1]<<1,p2[i]>=mod&&(p2[i]-=mod);
for(int i=1;i<=n;i++){
q[n-i]=q[n-i+1]^(s[i]-'0');
p[n-i]=(p[n-i+1]+(s[i]-'0')*p2[n-i])%mod;
}
int ans=0;
int m=p[0];
// cerr<<"> "<<m<<endl;
sp[0]=m;
for(int i=1;i<k;i++){
sp[i]=qpow(m,i+1);
for(int j=0;j<i;j++)
sp[i]=(sp[i]-c[i+1][j]*1ll*sp[j]%mod+mod)%mod;
sp[i]=sp[i]*1ll*qpow(i+1,mod-2)%mod;
// cerr<<"> "<<sp[i]<<endl;
}
for(int i=0;i<k;i++)ans=(ans+f[0][i]*1ll*sp[i])%mod;
for(int i=0;i<k;i++){
// for(int i=0;i<k;i++)cerr<<">"<<f[0][i];cerr<<endl;
// for(int i=0;i<k;i++)cerr<<">"<<f[1][i];cerr<<endl;cerr<<endl;
int q=::q[i],p=::p[i+1];
if(s[n-i]-'0'){
int v=0;
for(int j=0;j<k;j++)
v=(v*1ll*p+f[q][k-j-1])%mod;
ans=(ans+v)%mod;
v=0;
for(int j=0;j<k;j++)
v=(v*1ll*p+f[q^1][k-j-1])%mod;
ans=(ans+mod-v)%mod;
}
memcpy(g,f,sizeof(g));
int e=p2[i];
for(int j=0;j<k;j++){
int pe=1;
for(int l=j;l>=0;l--,pe=pe*1ll*e%mod){
int t=pe*1ll*c[j][l]%mod;
g[1][l]=(g[1][l]+f[0][j]*1ll*t)%mod;
g[0][l]=(g[0][l]+f[1][j]*1ll*t)%mod;
}
}
memcpy(f,g,sizeof(f));
}
ans=(mod+1ll)/2*ans%mod;
cout<<ans;
}