我竟然能做出九老师的组合计数题,尽管这题很呆
我们先考虑一个简单的问题,如果给定你要选出来的(m)张卡牌,如何做到攻击伤害最高
非常显然,因为保证了强化牌上的数字大于(1),所以我们优先选择那些强化牌,毕竟最小的一张只能翻两倍的强化牌都要比选择攻击牌好;选完强化牌之后剩下的攻击牌自然是越大越好
当然了,我们也不可能选出(k)张强化牌来,这样什么伤害都造不成,所以如果强化牌数量大于(k),我们就选择前(k-1)大的强化牌,配上最大的攻击牌
我们发现给定的强化牌和攻击牌都是无序的,我们先排序再说
之后我们就可以大力(dp)了
设(dp_{i,j})表示前(i)张强化牌里选出(j)张所有强化牌乘积的和,(f_{i,j})表示前(i)张强化牌里选择了(j)张且第(i)张一定被选择的乘积和
显然有转移
[dp_{i,j}=dp_{i-1,j}+dp_{i-1,j-1} imes a_i
]
[f_{i,j}=dp_{i-1,j-1} imes a_i
]
攻击牌这边我们设(h_{i,j})表示前(i)张攻击牌里选择了(j)张的所有方案的和,(g_{i,j})表示强迫选择第(i)张
也有转移
[h_{i,j}=h_{i-1,j}+h_{i-1,j-1}+C_{i-1}^{j-1} imes b_i
]
[g_{i,j}=h_{i-1,j-1}+C_{i-1}^{j-1} imes b_i
]
现在我们考虑枚举选择了多少张强化牌
设选择了(i)张强化牌,自然也就需要选择(m-i)张攻击牌
如果(i<k),那么这些强化牌都是要直接用的,选择的(m-i)张攻击牌里能被用于打出的(k)张的也就只有(k-i)张,而剩下的(m-i-(k-i)=m-k)张牌我们在后面随便选一下
于是答案就是
[sum_{j=0}^nf_{j,i} imes sum_{j=0}^ng_{j,k-i}C_{n-j}^{m-k}
]
如果(i>=k),那么我们也只能选择(k-1)张强化牌,剩下的多选出来的(i-k+1)张我们还是从后面的位置里选出来,实际上需要的攻击牌也就只有(1)张,需要额外多选的攻击牌也就只有(m-i-1)张
答案就是
[sum_{j=0}^nf_{j,k-1} C_{n-j}^{i-k+1} imes sum_{j=0}^ng_{j,1}C_{n-j}^{m-i-1}
]
代码
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
char c=getchar();int x=0;while(c<'0'||x>'9') c=getchar();
while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int mod=998244353;
const int maxn=3005;
int c[maxn][maxn];
int dp[maxn][maxn],f[maxn][maxn];
int g[maxn][maxn],h[maxn][maxn];
int T,n,m,k;
int a[maxn],b[maxn];
inline int C(int n,int m) {if(m>n) return 0;return c[n][m];}
inline int cmp(int A,int B) {return A>B;}
int main() {
T=read();int M=maxn-5;
for(re int i=0;i<=M;i++) c[i][0]=c[i][i]=1;
for(re int i=1;i<=M;i++)
for(re int j=1;j<i;j++)
c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
while(T--) {
n=read(),m=read(),k=read();
for(re int i=1;i<=n;i++) a[i]=read();
for(re int i=1;i<=n;i++) b[i]=read();
std::sort(a+1,a+n+1,cmp);std::sort(b+1,b+n+1,cmp);
dp[0][0]=1;
for(re int i=1;i<=n;i++)
for(re int j=0;j<=i;j++) {
dp[i][j]=dp[i-1][j];
if(j) dp[i][j]=(dp[i][j]+1ll*dp[i-1][j-1]*a[i]%mod)%mod;
}
f[0][0]=1;
for(re int i=1;i<=n;i++)
for(re int j=1;j<=i;j++)
f[i][j]=1ll*dp[i-1][j-1]*a[i]%mod;
h[0][0]=0;
for(re int i=1;i<=n;i++)
for(re int j=0;j<=i;j++) {
h[i][j]=h[i-1][j];
if(j) h[i][j]=(h[i][j]+1ll*C(i-1,j-1)*b[i]%mod+h[i-1][j-1])%mod;
}
for(re int i=1;i<=n;i++)
for(re int j=1;j<=i;j++)
g[i][j]=(1ll*C(i-1,j-1)*b[i]%mod+h[i-1][j-1])%mod;
int ans=0;
for(re int i=0;i<m;i++) {
int now=0;
for(re int j=0;j<=n;j++) {
if(i<=k-1) now=(now+f[j][i])%mod;
else now=(now+1ll*f[j][k-1]*C(n-j,i-k+1)%mod)%mod;
}
int tot=0,res=0;
if(i<=k-1) res=k-i;else res=1;
for(re int j=0;j<=n;j++)
tot=(tot+1ll*g[j][res]*C(n-j,m-i-res)%mod)%mod;
ans=(ans+1ll*now*tot%mod)%mod;
}
printf("%d
",ans);
}
return 0;
}