好神的一道计数题呀.
code:
#include <cstdio> #include <algorithm> #include <cstring> #define N 5000003 #define ll long long #define mod 998244353 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int invg[N],dp[N],f[N],fac[N],inv[N]; ll g[N]; int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) if(y&1) tmp=(ll)tmp*x%mod; return tmp; } int C(int x,int y) { return (ll)fac[x]*inv[y]%mod*inv[x-y]%mod; } int INV(int x) { return qpow(x,mod-2); } void solve() { int n,m,l,mi,kth,i,j; scanf("%d%d%d%d",&n,&m,&l,&kth); mi=min(min(n,m),l); if(kth>mi) { printf("0 "); return ; } ll tot=1ll*n*m%mod*l%mod,in=1ll; g[0]=tot%mod; for(i=1;i<=mi;++i) { g[i]=(tot-1ll*(n-i)*(m-i)%mod*(l-i)%mod+mod)%mod; in=in*g[i]%mod; } invg[mi]=qpow(in,mod-2); for(i=mi-1;i>=0;--i) invg[i]=(ll)invg[i+1]*g[i+1]%mod; f[0]=1; for(i=0;i<mi;++i) f[i+1]=(ll)f[i]*(n-i)%mod*(m-i)%mod*(l-i)%mod; for(i=0;i<=mi;++i) dp[i]=(ll)f[i]*invg[i]%mod; int ans=0; for(i=kth;i<=mi;++i) { int d=((i-kth)&1)?(mod-1):1; (ans+=(ll)d*C(i,kth)%mod*dp[i]%mod)%=mod; } printf("%d ",ans); } void init() { fac[0]=1; for(int i=1;i<N;i++) fac[i]=(ll)fac[i-1]*i%mod; inv[N-1]=qpow(fac[N-1],mod-2); for(int i=N-2;i>=0;i--) inv[i]=(ll)inv[i+1]*(i+1)%mod; } int main() { // setIO("input"); init(); int i,j,T; scanf("%d",&T); while(T--) solve(); return 0; }