放在了考试T1
发现70分的DP很水啊,f[i][j]为当前位置是i分配了j个队的方案
我们用前缀和统计,在将i删去,j倒序枚举,就可以删掉一维(也可以滚动数组滚起来)
1 #include<iostream> 2 #include<cstdio> 3 #include<cmath> 4 #include<cstring> 5 #include<string> 6 #include<algorithm> 7 #include<vector> 8 #include<map> 9 #define MAXN 110001 10 #define int long long 11 using namespace std; 12 int f[MAXN],sum[MAXN]; 13 int n,m,K; 14 const int mod=998244353; 15 int ans=0; 16 signed main() 17 { 18 // freopen("text.in","r",stdin); 19 // freopen("a.out","w",stdout); 20 scanf("%lld%lld%lld",&n,&m,&K); 21 m-=n;K-=1; 22 if(m<0||K<0) 23 { 24 printf("0 "); 25 return 0; 26 } 27 f[0]=1;sum[0]=1; 28 for(int j=1;j<=m;++j)sum[j]=sum[j-1]+f[j]; 29 for(int i=1;i<=n;++i) 30 { 31 for(int j=m;j>=0;--j) 32 { 33 if(j-K-1>=0) 34 f[j]=(sum[j]-sum[j-K-1]+mod)%mod; 35 else 36 f[j]=sum[j]%mod; 37 // printf("%lld f[%lld]=%lld ",i,j,f[j]); 38 } 39 sum[0]=f[0]; 40 for(int j=1;j<=m;++j) 41 { 42 sum[j]=(sum[j-1]+f[j]+mod)%mod; 43 } 44 } 45 printf("%lld ",f[m]%mod); 46 } 47 /* 48 3 7 3 49 */
正解是容斥
每次枚举至少有i个不符合条件,很显然满足容斥定理
然后C(n,i)表示在n个城市中任选出i个不符合的
C(m-1-i*K,n-1)显然挡板法,现将i*K这一不符合条件的部分删去,然后在剩余区间中
插入n-1个板子,意味着分成n份,
又因为C(n,i)是选出不符合的i个,然后相乘,就是至少为i时的所有情况
我们发现至少为i的情况包括i+1,i+2.....的情况,
那么我们可以看出i+1的情况可以看成是在至少为i时,确定了i,又选了1个
以至少为一为例,
选了为2的情况为C(2,1)次,(2集合中的两个数都被1集合重复选过)
那么容斥一下至少为2时改为加
至于容斥更严谨的证明,可以去看别人博客,我就不证了.......
1 #include<iostream> 2 #include<cstdio> 3 #include<cmath> 4 #include<cstring> 5 #include<string> 6 #include<algorithm> 7 #include<vector> 8 #include<map> 9 #define MAXN 31000001 10 #define int long long 11 using namespace std; 12 const int mod=998244353; 13 int n,m,K; 14 int ni_c[MAXN],jie[MAXN],ni[MAXN]; 15 int C(int x,int y) 16 { 17 if(y>x)return 0; 18 if(y==0)return 1; 19 return jie[x]*ni_c[y]%mod*ni_c[x-y]%mod; 20 } 21 signed main() 22 { 23 scanf("%lld%lld%lld",&n,&m,&K); 24 int ans=0; 25 jie[1]=1;ni[1]=1;ni_c[1]=1; 26 jie[0]=1;ni[0]=1;ni_c[0]=1; 27 for(int i=2;i<=m+1;++i) 28 { 29 jie[i]=jie[i-1]*i%mod; 30 ni[i]=(mod-mod/i)*ni[mod%i]%mod; 31 ni_c[i]=ni_c[i-1]*ni[i]%mod; 32 } 33 ans=C(m-1,n-1)%mod; 34 int base=n; 35 for(int i=1;i<=n;++i) 36 { 37 if(m-(i*K)<n)break; 38 if((i%2)==0) 39 { 40 ans=(ans+base*C(m-1-(i*K),n-1)%mod+mod)%mod; 41 } 42 else 43 { 44 ans=(ans-base*C(m-1-(i*K),n-1)%mod+mod)%mod; 45 } 46 base=(base*(n-i)%mod)*ni[i+1]%mod; 47 } 48 printf("%lld ",ans%mod); 49 }