题目链接
(Solution)
观察题目发现恰好出现了(s)次的颜色有(k)种,不太好弄.
所以我们设(a[i])表示为恰好出现了(s)次的颜色有至少(i)种的方案数,然后容斥一下
我们看一看(a[i])怎么求?
这很明显可以一眼看出来
[a[i]=C_m^i*C_n^{is}*(m-i)^{n-is}*frac{(is)!}{{s!}^i}
]
在(m)个颜色中选则(i)个,在(n)个位置选择(i*s)个,将这(i*s)个数全排列(有重复元素),剩下的位置随便选
算完这个以后可以开始容斥求答案,设答案数组为(F[i])
[F[i]=sum_{j=i}^{limit}(-1)^{j-i}*C_j^i*a[j]
]
再将组合数拆开:
[F[i]=sum_{j=i}^{limit}(-1)^{j-i}*frac{j!}{(j-i)!*i!}*a[j]
]
[F[i]*i!=sum_{j=i}^{limit}frac{(-1)^{j-i}}{(j-i)!}*a[j]*j!
]
令(f[i]=frac{(-1)^{i}}{(i)!},g[i]=a[j]*j!)
所以原式为:
[F[i]*i!=sum_{j=i}^{limit}f[j-i]*g[i]
]
然后将(f)数组翻转,原式为:
[F[i]*i!=sum_{j=i}^{limit}f[n-j+i]*g[i]
]
然后发现(j)从(1)开始对答案没有影响,这又是个卷积,又因为有(\%)数,所以直接(ntt)求即可
(Code)
#include<iostream>
#include<cstdlib>
#include<cstdio>
#define int long long
#define rg register
#define file(x) freopen(x".in","r",stdin);freopen(x".out","w",stdout);
using namespace std;
const int mod=1004535809;
const int N=10000001+10;
int read(){
int x=0,f=1;
char c=getchar();
while(c<'0'||c>'9') f=(c=='-')?-1:1,c=getchar();
while(c>='0'&&c<='9') x=x*10+c-48,c=getchar();
return f*x;
}
int f[N],g[N],r[N],limit=1,w[N],inv[N],jc[N];
int ksm(int a,int b){
int ans=1;
while(b){
if(b&1) ans=ans*a%mod;
a=a*a%mod,b>>=1;
}
return ans;
}
void fft(int *a,int opt){
for(int i=0;i<=limit;i++)
if(i<r[i])
swap(a[i],a[r[i]]);
for(int i=1;i<limit;i<<=1){
int w=ksm(3,(mod-1)/(i*2));
if(opt==-1) w=ksm(w,mod-2);
for(int j=0;j<limit;j+=i<<1){
int l=1;
for(int k=j;k<j+i;k++){
int p=l*a[k+i]%mod;
a[k+i]=(a[k]-p+mod)%mod;
a[k]=(a[k]+p)%mod;
l=l*w%mod;
}
}
}
}
int C(int n,int m){
return jc[n]*inv[m]%mod*inv[n-m]%mod;
}
main(){
int n=read(),m=read(),s=read(),l=0,M=min(m,n/s),ans=0;
for(int i=0;i<=m;i++)
w[i]=read();
while(limit<=M<<1)
limit<<=1,l++;
jc[0]=1;
for(int i=1;i<=max(n,m);i++)
jc[i]=jc[i-1]*i%mod;
inv[max(n,m)]=ksm(jc[max(n,m)],mod-2);
for(int i=max(n,m)-1;i>=0;i--)
inv[i]=inv[i+1]*(i+1)%mod;
for(int i=0;i<=M;i++)
g[i]=C(m,i)*C(n,i*s)%mod*jc[s*i]%mod*ksm(inv[s],i)%mod*ksm(m-i,n-i*s)%mod;
for(int i=0;i<=M;i++){
g[i]=g[i]*jc[i]%mod;
f[M-i]=(ksm(-1,i)*inv[i]+mod)%mod;
}
for(int i=0;i<limit;i++)
r[i]=(r[i>>1]>>1)|((i&1)<<(l-1));
fft(f,1),fft(g,1);
for(int i=0;i<limit;i++)
f[i]=f[i]*g[i]%mod;
fft(f,-1);
for(int i=0;i<=M;i++)
ans=(ans+f[i+M]*ksm(limit,mod-2)%mod*inv[i]%mod*w[i]%mod)%mod;
printf("%lld",ans);
}