Solution
实际上是求最大伤害总和。
有一个只要有眼睛就能看出来的结论:能出强化牌就出强化牌,最后剩一张出攻击牌,当然如果强化牌不满(k)个就把强化牌出完剩下出攻击牌。因为强化牌都是大于等于1的正整数,所以带来的效果是至少让伤害翻一倍那么显然尽量出强化牌。(然鹅我可能真的没眼睛,看了20分钟才看到(w_i)都是正整数)
由乘法分配律,强化牌和攻击牌的贡献可以分开来计算。设(F_{i,j})表示所有的 (i)张强化牌里选(j)张打出的情况的贡献 之和,(G_{i,j})表示所有 (i)张攻击牌选(j)张打出的情况的贡献 之和。
那么有:
[ans=sumegin{cases}
F_{i,i} imes G_{m-i,k-i},i<k\
F_{i,k-1} imes G_{m-i,1},igeq k
end{cases}
]
但是F和G不太好算,考虑设(dp)数组辅助。
首先可以把牌从大到小排序,因为我肯定是在能选的牌堆中选最大的那些出掉。
设(f_{i,j})表示选i张强化牌打出,其中最小的那张是第j张的贡献。(g_{i,j})表示选i张攻击牌打出,其中最小的是第j张的贡献。简单dp一下,利用前缀和优化就是( ext O(n^2))的:
[egin{align}
&f_{i,j}=a_j imes sumlimits_{i-1leq kleq j-1}f_{i-1,k}\
&g_{i,j}=b_j imes inom{j-1}{i-1} + sumlimits_{i-1leq kleq j-1}g_{i-1,k}
end{align}
]
求出这个之后F,G就是f,g乘上一个组合数再算算的事情:
[F_{x,y}=sumlimits_{igeq y}f_{y,i} imes inom{n-i}{x-y}\
G_{x,y}=sumlimits_{igeq y}g_{y,i} imes inom{n-i}{x-y}
]
F,G对总共(O(n))个,所以计算答案也是(O(n^2))级别的。
像这种题不太会做的原因是想不到可以给状态加限制,例如此题加一个强制第j个是打出的牌中最小的那个就很好转移了。
另外就是函数C(n,m),就是组合数的函数,一定要记得特判 (n<m) 的情况!!!
Code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read(){//be careful for long long!
register int x=0,f=1;register char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
while(isdigit(ch)){x=x*10+(ch^'0');ch=getchar();}
return f?x:-x;
}
const int N=3e3+10,mod=998244353;
int n,m,k,a[N],b[N],tmp[N],fac[N],ifc[N];
int f[N][N],g[N][N];
inline int power(int base,int n){int ans=1;for(;n;n>>=1,base=1ll*base*base%mod)if(n&1)ans=1ll*ans*base%mod;return ans;}
inline int C(int n,int m){if(n<m)return 0;return 1ll*fac[n]*ifc[m%mod]%mod*ifc[n-m]%mod;}
inline bool cmp(const int &x,const int &y){return x>y;}
inline int F(int x,int y){
int ans=0;
for(int i=y;i<=n;++i)ans=(ans+1ll*f[y][i]*C(n-i,x-y)%mod)%mod;
return ans;
}
inline int G(int x,int y){
int ans=0;
for(int i=y;i<=n;++i)ans=(ans+1ll*g[y][i]*C(n-i,x-y)%mod)%mod;
return ans;
}
int main(){
for(int i=fac[0]=1;i<N;++i)fac[i]=1ll*fac[i-1]*i%mod;
ifc[N-1]=power(fac[N-1],mod-2);
for(int i=N-2;~i;--i)ifc[i]=1ll*ifc[i+1]*(i+1)%mod;
int T=read();
while(T--){
n=read(),m=read(),k=read();
for(int i=1;i<=n;++i)a[i]=read();
for(int i=1;i<=n;++i)b[i]=read();
sort(a+1,a+n+1,cmp),sort(b+1,b+n+1,cmp);
f[0][0]=1;
for(int i=1;i<=n;++i)f[1][i]=a[i],g[1][i]=b[i];
for(int i=2;i<=n;++i){
int sf=f[i-1][i-1],sg=g[i-1][i-1];
for(int j=i;j<=n;++j){
f[i][j]=1ll*a[j]*sf%mod;
g[i][j]=(1ll*b[j]*C(j-1,i-1)%mod+sg)%mod;
sf=(sf+f[i-1][j])%mod;sg=(sg+g[i-1][j])%mod;
}
}
int ans=0;
for(int i=max(m-n,0);i<=min(n,m-1);++i){
if(i<k)ans=(ans+1ll*F(i,i)*G(m-i,k-i)%mod)%mod;
else ans=(ans+1ll*F(i,k-1)*G(m-i,1)%mod)%mod;
}
printf("%d
",ans);
}
return 0;
}