Problem
Solution
在考场上当然要学会写暴力,考虑如果手上已经有了(a)张攻击牌和(b)张强化牌:
- 首先强化牌会在攻击牌之前用(废话),其次要将两种牌分别从大往小打,即排个序先(也是废话)
- 要尽量打强化牌,最后再打一张攻击牌(由于每张强化牌至少乘二,所以打一张强化牌一定不比多打一张攻击牌差)
由于(nleq 3000),预估复杂度为(O(n^2)),所以应该可以枚举两种牌的数量
设两个状态,(F[i][j])表示选取(i)张强化牌,打出(j)张的强化效果,(G[i][j])表示攻击牌,由于有(k)的限制,所以分类讨论一下:
- 若当前升级牌数量超过(k),则需要打出(k-1)张升级牌,再出一张最大值的攻击牌
- 若当前升级牌数量少于(k),则需要打出所有升级牌,再从攻击牌中找到最大的一些牌
然后答案就是
[Ans=sum_{i=0}^{k-1}F[i][i]cdot G[m-i][k-i]+sum_{i=k}^{min(n,m)}F(i,k-1)cdot G(m-i,1)
]
接下来考虑如何求出(F,G)数组,由于需要在所有选牌情况里需要贪心,所以一种解决方案是构造辅助数组
考虑对牌从大到小排序,(f[i][j],g[i][j])表示前(i)张牌中选取了(j)张牌,且必选自己的情况中的和
转移方程(可以利用前缀和做到(O(n^2))转移):
(f[i][j]=w_icdot sum_{l=j-1}^{i-1}f[l][j-1] \ g[i][j]=inom {i-1}{j-1}w_i+sum_{l=j-1}^{i-1}g[l][j-1])
然后就可以枚举第(j)张牌是第几张来将(f ightarrow F,g ightarrow G):
(F[i][j]=sum_{l=j}^ninom {n-l}{i-j}f[l][j] \ G[i][j]=sum_{l=j}^ninom {n-l}{i-j}g[l][j])
由于求一个(F,G)是(O(n))的,所以求出所有的(F,G)是(O(n^3))的,但由于我们计算答案的时候只需要用到(O(n))个(F,G),所以每次暴力从(f,g)统计即可
Code
#include <bits/stdc++.h>
using namespace std;
inline void read(int&x){
char c11=getchar();x=0;while(!isdigit(c11))c11=getchar();
while(isdigit(c11))x=x*10+c11-'0',c11=getchar();
}
const int N=3010,p = 998244353;
int fac[N],inv[N],f[N][N],g[N][N];
int w1[N],w2[N],t1[N],t2[N];
int n,m,k;
inline int qpow(int A,int B){
int res(1);while(B){
if(B&1)res=1ll*res*A%p;
A=1ll*A*A%p;B>>=1;
}return res;
}
inline int c(int nn,int mm){return 1ll*fac[nn]*inv[mm]%p*inv[nn-mm]%p;}
inline int cmp(const int&A,const int&B) {return A>B;}
template <typename _Tp> inline int qm(_Tp x){return x<p?x:x-p;}
void prework(){
fac[0]=inv[0]=1;
for(int i=1;i<N;++i)fac[i]=1ll*fac[i-1]*i%p;
inv[N-1]=qpow(fac[N-1],p-2);
for(int i=N-2;i;--i)inv[i]=1ll*inv[i+1]*(i+1)%p;
}
inline int F(int i,int j){
static int res;res=0;
for(int l=n-max(i-j,0);l>=j;--l)
res=qm(res+1ll*c(n-l,i-j)*f[l][j]%p);
return res;
}
inline int G(int i,int j){
static int res;res=0;
for(int l=n-max(i-j,0);l>=j;--l)
res=qm(res+1ll*c(n-l,i-j)*g[l][j]%p);
return res;
}
void init();void work();
int main(){
prework();int T;read(T);
while(T--)init(),work();
return 0;
}
void work(){
int ans=0;
for(int i=min(n,m);~i&&i>=m-n;--i)
if(i<k)ans=qm(ans+1ll*F(i,i)*G(m-i,k-i)%p);
else ans=qm(ans+1ll*F(i,k-1)*G(m-i,1)%p);
printf("%d
",ans);
}
void init(){
read(n),read(m),read(k);f[0][0]=1;
memset(t1,0,sizeof(t1));memset(t2,0,sizeof(t2));
for(int i=1;i<=n;++i)read(w1[i]);sort(w1+1,w1+n+1,cmp);
for(int i=1;i<=n;++i)read(w2[i]);sort(w2+1,w2+n+1,cmp);
for(int i=1;i<=n;++i)f[i][1]=w1[i],g[i][1]=w2[i];
for(int i=1;i<=n;++i)
for(int j=2;j<=i;++j){
t1[j-1]=qm(t1[j-1]+f[i-1][j-1]);
f[i][j]=1ll*w1[i]*t1[j-1]%p;
t2[j-1]=qm(t2[j-1]+g[i-1][j-1]);
g[i][j]=qm(1ll*c(i-1,j-1)*w2[i]%p+t2[j-1]);
}
}