[AGC019E]Shuffle and Swap
题目大意:
给出两个长度为(n(nle10000))的(01)串(A_{1sim n})和(B_{1sim n})。两个串均有(k)个'1'
。令(a_{1sim k})和(b_{1sim k})分别表示(A)和(B)中所有'1'
出现的位置。将(a)和(b)等概率随机排列,按(1sim k)的顺序交换(A_{a_i})和(A_{b_i})。令(P)表示操作完成后(A)与(B)相等的概率,求(P imes(k!)^2)在模(998244353)意义下的值。
思路:
将(k)个'1'
分为两类:(S_1)和(S_2)。(S_1)表示(A)和(B)中均为'1'
,有(s_1)个;(S_2)表示(A)和(B)中只有一个为'1'
,有(s_2)个。用(f[i][j])表示(S_1)交换完(i)个,(S_2)交换完(j)个后的方案数。
显然,初始情况下,(S_1)可以交换(2)次,(S_2)可以交换(1)次。方便起见,假设我们的(A)和(B)分别为:
110011
111100
对于(i=0)的初始情况,由于(S_2)之间任意配对都可以使得(A=B),因此有((j!)^2)种方案。
若我们交换一对(S_2),如(4)和(5),那么交换后,可以交换(1)次的(S_2)减少了一对。有转移(f[i][j]+=f[i][j-1] imes j^2)。
若我们交换一个(S_1)和一个(S_2),如(2)和(6),那么交换后,原来那个(S_2)已经不能再交换了,而原来可以交换(2)次的(S_1)现在只能交换一次,可以算入(S_2)中,也就是相当于减少了一个(S_1)。有转移(f[i][j]+=f[i-1][j] imes i imes j)。
最后统计答案时,对于(S_1)没有用完的情况,无论如何配对都能满足(A=B)。因此(f[s_1-i][s_2])对答案的贡献为(f[s_1-i][s_2] imes(i!)^2 imesinom{s_1}i imesinom ki)。
时间复杂度(mathcal O(n^2))。
源代码:
#include<cstdio>
#include<cstring>
using int64=long long;
constexpr int N=10001,mod=998244353;
char a[N],b[N];
int fact[N],factinv[N],f[N][N];
void exgcd(const int &a,const int &b,int &x,int &y) {
if(!b) {
x=1,y=0;
return;
}
exgcd(b,a%b,y,x);
y-=a/b*x;
}
inline int inv(const int &x) {
int ret,tmp;
exgcd(x,mod,ret,tmp);
return (ret%mod+mod)%mod;
}
inline int C(const int &n,const int &m) {
if(n<m||n<0||m<0) return 0;
return (int64)fact[n]*factinv[m]%mod*factinv[n-m]%mod;
}
int main() {
scanf("%s%s",a,b);
const int n=strlen(a);
for(register int i=fact[0]=1;i<=n;i++) {
fact[i]=(int64)fact[i-1]*i%mod;
}
factinv[n]=inv(fact[n]);
for(register int i=n;i;i--) {
factinv[i-1]=(int64)factinv[i]*i%mod;
}
int s1=0,s2=0;
for(register int i=0;i<n;i++) {
if(a[i]=='1'&&b[i]=='1') s1++;
if(a[i]=='1'&&b[i]=='0') s2++;
}
for(register int i=0;i<=s2;i++) {
f[0][i]=(int64)fact[i]*fact[i]%mod;
}
for(register int i=1;i<=s1;i++) {
for(register int j=1;j<=s2;j++) {
f[i][j]=((int64)f[i-1][j]*i%mod*j%mod+(int64)f[i][j-1]*j%mod*j%mod)%mod;
}
}
int ans=0;
for(register int i=0;i<=s1;i++) {
(ans+=(int64)f[s1-i][s2]*fact[i]%mod*fact[i]%mod*C(s1,i)%mod*C(s1+s2,i)%mod)%=mod;
}
printf("%d
",ans);
return 0;
}