这题72分做法挺显然的(也是我VP的分):
对于n,D<=5000的数据,可以记录f[i][j]表示到第i次随机有j个数字未匹配的方案,直接O(nD)的DP转移即可。
对于D<=300的数据,根据转移系数建立矩阵,跑一遍矩阵快速幂,复杂度O(D3logn),不过要注意卡常数,因为是稀疏矩阵可以判掉无用状态。
对于m较小数据,m=0快速幂,m=1为Dn-A(n,D),m=2暴力讨论一下有没有出现>=1次的值,如果有,唯一出现>=1次的值是出现2次还是3次。
当然还是水平低啊不会正解。正解是生成函数。转化是对的,匹配数>=m就是未匹配的数<=min(D,n-2m),未匹配的数实际上就是出现奇数次的数。一个数出现奇数次的生成函数是:(ex+e-x)/2,偶数次为:(ex-e-x)/2。然后ans=n!(Σ((ex+e-x)/2+y(ex-e-x)/2)D[xn][yk]),其中0<=k<=n-2m,由于我不会用LaTeX,打数学公式太长太慢了,直接写最终式子的结果:ans=(1/2)DΣC(D,i)(2i-D)nΣ(1-y)i(1-y)D-i[yk],其中0<=i<=D,0<=k<=n-2m,然后将式子展开后发现后面的是一个阶乘式,阶乘展开后又是一个卷积形式,再加上mod=998244353,直接NTT处理即可。
#include<bits/stdc++.h> using namespace std; const int N=3e5+7,mod=998244353,inv2=499122177; int D,n,m,nn,ans,fac[N],inv[N],R[N],f[N],A[N],B[N]; int qpow(int a,int b) { int ret=1; while(b) { if(b&1)ret=1ll*ret*a%mod; a=1ll*a*a%mod,b>>=1; } return ret; } void NTT(int*a,int tp) { for(int i=0;i<nn;i++)if(i<R[i])swap(a[i],a[R[i]]); for(int i=1;i<nn;i<<=1) { int wn=qpow(3,mod/(i<<1)); if(tp==-1)wn=qpow(wn,mod-2); for(int j=0;j<nn;j+=i<<1) for(int k=0,w=1;k<i;k++,w=1ll*w*wn%mod) { int x=a[j+k],y=1ll*w*a[i+j+k]%mod; a[j+k]=(x+y)%mod,a[i+j+k]=(x-y+mod)%mod; } } if(tp==1)return; int invn=qpow(nn,mod-2); for(int i=0;i<nn;i++)a[i]=1ll*a[i]*invn%mod; } int C(int a,int b){return 1ll*fac[a]*inv[b]%mod*inv[a-b]%mod;} int main() { scanf("%d%d%d",&D,&n,&m); m=n-2*m; fac[0]=1;for(int i=1;i<=1e5;i++)fac[i]=1ll*fac[i-1]*i%mod; inv[100000]=qpow(fac[100000],mod-2);for(int i=1e5;i;i--)inv[i-1]=1ll*inv[i]*i%mod; if(m>=D){printf("%d",qpow(D,n));return 0;} if(m<=0) { for(int i=-D;i<=D;i++) if((D+i)%2==0)ans=(ans+1ll*qpow(i+mod,n)*C(D,D+i>>1))%mod; ans=1ll*ans*qpow(inv2,D)%mod; printf("%d",ans); return 0; } A[0]=1;for(int i=1;i<=D;i++)A[i]=1ll*C(i-1,m)*(m&1?mod-1:1)%mod; reverse(A,A+D+1); for(int i=0;i<=D;i++)A[i]=1ll*A[i]*qpow(2,i)%mod*inv[i]%mod; for(int i=0;i<=D;i++)B[i]=1ll*inv[i]*(i&1?mod-1:1)%mod; nn=1;int L=0; while(nn<=D*2)nn*=2,L++; for(int i=0;i<nn;i++)R[i]=R[i>>1]>>1|((i&1)<<L-1); NTT(A,1),NTT(B,1); for(int i=0;i<nn;i++)f[i]=1ll*A[i]*B[i]%mod; NTT(f,-1); for(int i=0;i<=D;i++)ans=(ans+1ll*C(D,i)*qpow(mod+2*i-D,n)%mod*f[i]%mod*fac[i])%mod; ans=1ll*ans*qpow(inv2,D)%mod; printf("%d",ans); }