题面链接
sol
好神啊。果然(dp)还是做少了,纪录一下现在的思维吧(QAQ)。
我们首先可以发现期望是骗人的,要不然他乘的是什么xjb玩意。
其实就是要求所有方案的最优方案和。
因为(w_i)是大于1的,所以能强化先强化,再从大往小打攻击牌。
那么我们枚举用了(a)张强化,(b)张攻击。
若(a<k),显然强化牌选完,攻击牌从大到小。
否则,选前(k-1)大的强化牌,再选最大的攻击牌。
我们如何做到最优呢,这里有一个套路,先排序,这样可以保证每一种方案这一定是最优的。
设(f[i][j])表示用了(i)张强化牌最后一张是第(j)张的所有的方案的乘积。
设(g[i][j])表示用了(i)张攻击牌最后一张是第(j)张的所有的方案的和。
注意我们是求所有的方案。
设(F[i][j])表示取了(i)张强化牌,用了(j)张的方案的乘积。
设(G[i][j])表示取了(i)张攻击牌,用了(j)张的方案的和。
那么有
[F[x][y]=sum_{i=1}^nf[y][i]*inom{n-i}{x-y}
]
[G[x][y]=sum_{i=1}^ng[y][i]*inom{n-i}{x-y}
]
一张强化牌都不能要的时候。(F[x][y]=inom{n}{x})
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define gt getchar()
#define ll long long
#define File(s) freopen(s".in","r",stdin),freopen(s".out","w",stdout)
inline int in()
{
int k=0;char ch=gt;
while(ch<'-')ch=gt;
while(ch>'-')k=k*10+ch-'0',ch=gt;
return k;
}
const int N=3005,YL=998244353;
inline int MO(const int &a){return a>=YL?a-YL:a;}
inline int ksm(int a,int k){int r=1;while(k){if(k&1)r=1ll*r*a%YL;a=1ll*a*a%YL,k>>=1;}return r;}
int f[N][N],g[N][N],a[N],b[N],sum[N],fac[N],fnv[N],n,m,k;
inline int C(int x,int y){return y>x?0:1ll*fac[x]*fnv[y]%YL*fnv[x-y]%YL;}
inline int qh(int x,int y)
{
if(y>x)return 0;if(!y)return C(n,x);
int res=0;
for(int i=1;i<=n;++i)
res=MO(res+1ll*f[y][i]*C(n-i,x-y)%YL);
return res;
}
inline int gj(int x,int y)
{
if(y>x)return 0;if(!y)return 0;
int res=0;
for(int i=1;i<=n;++i)
res=MO(res+1ll*g[y][i]*C(n-i,x-y)%YL);
return res;
}
int main()
{
fac[0]=fnv[0]=1;
for(int i=1;i<=3000;++i)fac[i]=1ll*fac[i-1]*i%YL;fnv[3000]=ksm(fac[3000],YL-2);
for(int i=3000;i>=1;--i)fnv[i-1]=1ll*fnv[i]*i%YL;
int t=in();
while(t--)
{
n=in(),m=in(),k=in();
for(int i=1;i<=n;++i)a[i]=in();
for(int i=1;i<=n;++i)b[i]=in();
std::sort(a+1,a+n+1,std::greater<int>());
std::sort(b+1,b+n+1,std::greater<int>());
for(int i=1;i<=n;++i)f[1][i]=a[i],sum[i]=MO(sum[i-1]+a[i]);
for(int i=2;i<=n;++i)
{
for(int j=1;j<=n;++j)f[i][j]=1ll*a[j]*sum[j-1]%YL;
for(int j=1;j<=n;++j)sum[j] =MO(sum[j-1]+f[i][j]);
}
for(int i=1;i<=n;++i)g[1][i]=b[i],sum[i]=MO(sum[i-1]+b[i]);
for(int i=2;i<=n;++i)
{
for(int j=1;j<=n;++j)g[i][j]=MO(1ll*b[j]*C(j-1,i-1)%YL+sum[j-1]);
for(int j=1;j<=n;++j)sum[j] =MO(sum[j-1]+g[i][j]);
}
int ans=0;
for(int a=0,t=std::min(n,m);a<=t;++a)
{
int b=m-a;if(b>n||b<0)continue;
ans=MO(ans+(a<k?1ll*qh(a,a)*gj(b,k-a)%YL:1ll*qh(a,k-1)*gj(b,1)%YL));
}
printf("%d
",ans);
}
return 0;
}