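// Analysis: main evaluates an inclusion-exclusion sum over i, the number of
// non-overlapping 4-item blocks that are forced to appear, where each block uses one
// item of every kind (hence the reduced bounds A-i, B-i, C-i, D-i). C(n-3i, i) places
// the i blocks among the n-3i remaining units, and get_sum counts the arrangements of
// the other n-4i items as n'! * [x^n'] of a product of truncated exponential
// generating functions, multiplied together with an NTT.
// Code: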
#include<bits/stdc++.h>
using namespace std;
#define int long long
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
const int mod = 998244353;
// Primitive root used for the forward NTT.
const int g = 3;
// p[i] = i! mod mod and inv[i] = (i!)^{-1} mod mod; main fills indices 0..2000.
int p[40010], inv[40010];
// G = g^{-1} mod mod, the root used for the inverse NTT (set in main).
int G;
// cc[i][j] = C(i, j) mod mod for 0 <= j <= i <= 2000 (Pascal's triangle, set in main).
int cc[2100][2100];
// NTT work buffers for the four generating functions, plus the bit-reversal table.
int a[40010], b[40010], c[40010], d[40010], r[40010];
inline int pw(int x,int p){
int res=1;
while(p){
if(p&1)res=1ll*res*x%mod;
x=1ll*x*x%mod;
p>>=1;
}
return res;
}
// In-place iterative number-theoretic transform of a[0..n-1] modulo mod.
// n must be a power of two and r[] must already hold the bit-reversal permutation
// for n (get_sum builds it before calling). f == 1 selects the forward transform
// with root g; any other value uses G (= g^{-1}), giving the inverse transform
// without the final 1/n scaling, which the caller applies.
inline void ntt(int a[],int n,int f){
    // Put the input into bit-reversed order so the butterflies can run bottom-up.
    for(int idx=0;idx<n;++idx){
        if(idx<r[idx])swap(a[idx],a[r[idx]]);
    }
    const int root=(f==1)?g:G;
    for(int half=1;half<n;half<<=1){
        const int span=half<<1;
        // Principal span-th root of unity; mod-1 = 119 * 2^23 is divisible by span.
        const int step=pw(root,(mod-1)/span);
        for(int base=0;base<n;base+=span){
            int tw=1;  // current twiddle factor step^off
            for(int off=0;off<half;++off){
                const int lo=a[base+off];
                const int hi=1ll*a[base+off+half]*tw%mod;
                a[base+off]=(lo+hi)%mod;
                a[base+off+half]=(lo-hi+mod)%mod;
                tw=1ll*tw*step%mod;
            }
        }
    }
}
// Bounded multinomial count:
//   returns n! * [x^n] prod_{X in {A,B,C,D}} ( sum_{k=0}^{X} x^k / k! )   (mod mod),
// i.e. the number of arrangements of n items of four kinds in which the four kinds
// are used at most A, B, C and D times respectively.
// Returns 0 when n < 0, n > A+B+C+D, or any bound is negative.
// Overwrites the global work buffers a, b, c, d and the bit-reversal table r.
// Relies on p[] / inv[] (factorials / inverse factorials) covering indices 0..n.
inline int get_sum(int n,int A,int B,int C,int D){
    if(n<0||n>A+B+C+D)return 0;
    // A negative bound makes its factor the zero polynomial, so the product is 0.
    if(A<0||B<0||C<0||D<0)return 0;
    // Terms of degree > n cannot contribute to [x^n]. Truncating keeps the transform
    // small (bounded by 4n instead of 2*(A+B+C+D)), so the buffers never overflow and
    // inv[] is only read at indices <= n.
    A=std::min(A,n);
    B=std::min(B,n);
    C=std::min(C,n);
    D=std::min(D,n);
    const int deg=A+B+C+D;  // degree of the truncated product (>= n)
    // Smallest power of two strictly greater than deg, and at least 2. This means the
    // cyclic convolution does not wrap into [x^n], and len >= 1, so the shift by
    // len-1 below is well defined. The original version used m = 1 with len = 0 when
    // the bounds summed to 0.
    int m=2,len=1;
    while(m<=deg)m<<=1,len++;
    for(int i=0;i<m;i++)r[i]=((r[i>>1]>>1)|((i&1)<<(len-1)));
    for(int i=0;i<m;i++){
        a[i]=(i<=A)?inv[i]:0;
        b[i]=(i<=B)?inv[i]:0;
        c[i]=(i<=C)?inv[i]:0;
        d[i]=(i<=D)?inv[i]:0;
    }
    ntt(a,m,1),ntt(b,m,1),ntt(c,m,1),ntt(d,m,1);
    for(int i=0;i<m;i++)a[i]=1ll*a[i]*b[i]%mod*c[i]%mod*d[i]%mod;
    ntt(a,m,-1);
    // Undo the factor m introduced by the inverse transform, then scale by n!.
    return 1ll*p[n]*a[n]%mod*pw(m,mod-2)%mod;
}
// Reads n, A, B, C, D and prints
//   sum_{i=0}^{floor(n/4)} (-1)^i * C(n-3i, i) * get_sum(n-4i, A-i, B-i, C-i, D-i)   (mod mod).
// The factorial and binomial tables cover indices up to 2000, so n must be in [0, 2000].
signed main(){
    int n,A,B,C,D;
    G=pw(g,mod-2);  // inverse of the primitive root, used by the inverse NTT
    // Factorials: p[i] = i!.
    p[0]=1;
    for(int i=1;i<=2000;i++)p[i]=1ll*p[i-1]*i%mod;
    // Inverse factorials: inv[i] = 1/i!, computed backwards from 1/2000! (Fermat inverse).
    inv[2000]=pw(p[2000],mod-2);
    for(int i=1999;i>=0;i--)inv[i]=1ll*inv[i+1]*(i+1)%mod;
    // Pascal's triangle: cc[i][j] = C(i, j).
    for(int i=0;i<=2000;i++)cc[i][0]=cc[i][i]=1;
    for(int i=1;i<=2000;i++)
        for(int j=1;j<i;j++)cc[i][j]=(cc[i-1][j]+cc[i-1][j-1])%mod;
    if(scanf("%lld%lld%lld%lld%lld",&n,&A,&B,&C,&D)!=5)return 1;
    // The precomputed tables are indexed by n, n-3i and n-4i; reject n they cannot cover.
    if(n<0||n>2000){
        fprintf(stderr,"n must be in [0, 2000]\n");
        return 1;
    }
    int Ans=0;
    for(int i=0;i<=n/4;i++){
        // Term with i forced blocks; its sign alternates with i (inclusion-exclusion).
        const int term=1ll*cc[n-3*i][i]*get_sum(n-4*i,A-i,B-i,C-i,D-i)%mod;
        Ans=(i&1)?(Ans-term+mod)%mod:(Ans+term)%mod;
    }
    printf("%lld\n",Ans);
    return 0;
}