题目链接:
直接求不好求,我们考虑容斥,求出至少有$i$个聚集区间的方案数$ans_{i}$,那么最终答案就是$sumlimits_{i=0}^{n}(-1)^i ans_{i}$
那么现在只需要考虑至少有$i$个聚集区间的方案数,我们枚举这$i$个区间的起始点位置,一共有$C_{n-3i}^{i}$种方案(可以看作是刚开始先将每个区间后三个位置去掉,从剩下$n-3i$个位置中选出$i$个区间起点,然后再在每个起点后面加上三个位置)。
那么剩下的$n-4i$个位置就是随便放这四种学生,假设第$j$种学生放了$a_{j}$个、一共有$num_{j}$个,那么方案数就是$frac{(n-4i)!}{prod_{j=1}^{4}a_{j}!}$。
由此可以构造出这四种学生的生成函数,以第一种学生为例:$sumlimits_{j=0}^{num_{1}-i}frac{x^j}{j!}$
将四个生成函数分别用$NTT$乘在一起然后取$x^{n-4i}$前的系数乘上$(n-4i)!$即可得到$n-4i$个位置随便放的方案数。
#include<set> #include<map> #include<cmath> #include<stack> #include<queue> #include<bitset> #include<cstdio> #include<vector> #include<cstring> #include<iostream> #include<algorithm> using namespace std; const int mod=998244353; int f[3000]; int g[3000]; int inv[2000]; int fac[2000]; int mask; int n,a,b,c,d; int ans; int mn,mx; int quick(int x,int y) { int res=1; while(y) { if(y&1) { res=1ll*res*x%mod; } x=1ll*x*x%mod; y>>=1; } return res; } void NTT(int *a,int len,int opt) { for(int i=0,k=0;i<len;i++) { if(i>k) { swap(a[i],a[k]); } for(int j=len>>1;(k^=j)<j;j>>=1); } for(int i=2;i<=len;i<<=1) { int t=i>>1; int x=quick(3,(mod-1)/i); if(opt==-1) { x=quick(x,mod-2); } for(int j=0;j<len;j+=i) { int w=1; for(int k=j;k<j+t;k++) { int tmp=1ll*a[k+t]*w%mod; a[k+t]=(a[k]-tmp+mod)%mod; a[k]=(a[k]+tmp)%mod; w=1ll*w*x%mod; } } } if(opt==-1) { int x=quick(len,mod-2); for(int i=0;i<len;i++) { a[i]=1ll*a[i]*x%mod; } } } int C(int n,int m) { return 1ll*fac[n]*inv[m]%mod*inv[n-m]%mod; } int solve(int x) { memset(f,0,sizeof(f)); memset(g,0,sizeof(g)); for(int i=0;i<=a-x;i++) { f[i]=inv[i]; } for(int i=0;i<=b-x;i++) { g[i]=inv[i]; } NTT(f,mask,1); NTT(g,mask,1); for(int i=0;i<mask;i++) { f[i]=1ll*f[i]*g[i]%mod; } memset(g,0,sizeof(g)); for(int i=0;i<=c-x;i++) { g[i]=inv[i]; } NTT(g,mask,1); for(int i=0;i<mask;i++) { f[i]=1ll*f[i]*g[i]%mod; } memset(g,0,sizeof(g)); for(int i=0;i<=d-x;i++) { g[i]=inv[i]; } NTT(g,mask,1); for(int i=0;i<mask;i++) { f[i]=1ll*f[i]*g[i]%mod; } NTT(f,mask,-1); return 1ll*f[n-4*x]*fac[n-4*x]%mod*C(n-3*x,x)%mod; } int main() { inv[1]=inv[0]=fac[0]=1; for(int i=1;i<=1000;i++) { fac[i]=1ll*fac[i-1]*i%mod; } for(int i=2;i<=1000;i++) { inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod; } for(int i=1;i<=1000;i++) { inv[i]=1ll*inv[i-1]*inv[i]%mod; } mask=1; scanf("%d%d%d%d%d",&n,&a,&b,&c,&d); mn=min(n/4,min(min(a,b),min(c,d))); mx=max(max(a,b),max(c,d)); while(mask<=(mx<<2)) { mask<<=1; } for(int i=0;i<=mn;i++) { if(i&1) { ans=(ans-solve(i)+mod)%mod; } else { ans=(ans+solve(i))%mod; } } printf("%d",ans); }