题目链接:E - Shuffle and Swap
题目大意:洛谷
题解:这一道题有两个做法。
Solution 1
考虑有 (n) 个位置上 (A,B) 均为 1
,设为位置 A ,有 (2m) 个位置上 (A) 为 1
, 或 (B) 为 1
,设为位置 B(不包含位置 A),那么我们的目标就是让最后位置 B 的个数减少为 0,那么我们每一次交换位置 A 中的两个数并不会是位置 B 的个数发生变化,所以不管它,如果我们交换位置 B 中的两个数则会使位置 B 的个数减 1,这种这种情况的方案数是 (m^2) ,如果我们交换一个位置 A 上的数和一个位置 B 上的数,则会使位置 A 的个数减 1 ,位置 B 的个数不变。
因此,我们的转移方程就是 (f_{n,m}= f_{n,m-1} imes m^2+f_{n-1,m} imes n imes m) 。
最后我们考虑位置 A 还剩余的方案数,(ans =sum_{i=0}^{n} f_{n-i,m} imes (i!)^2 imes inom{n}{i} imes inom{n+m}{i})。
时间复杂度和空间复杂度均为 (O(n^2))。
Solution 2
考虑一个重要的转化:题目中的 ((k!)^2) 中方案我们可以分成两个步骤,第一个步骤是给每一个 (A) 中的 1
匹配一个 (B) 中的 1
,第二个步骤是给这些情况重新排列。
现在我们更改一下 (n,m) 的定义,令 (n) 表示 (A,B) 上有一个位置为 1
的方案数, (m) 的意义同 Solution 1 中的不变。
考虑步骤一,那么如果我们对于 (A) 中的 1
的位置,向和它匹配的 (B) 中的位置连一条边,那么我们会发现,整张图被我们分成了若干条链和若干个环,链的个数恰好是 (m) 个,并且,链的选择是有顺序要求的,即必须从链首选到链尾,而环的选择则没有要求了,接下来我们考虑使用生成函数来表示这个东西,因为两条链或者两个环在组合的时候是需要乘上组合数的,所以考虑用指数型生成函数来解决。
链的指数型生成函数:(假设起点已经确定,所以在结束之后还需要乘上(m!)到答案中, ([x^i]F(x)) 表示在链的端点之间有 (i) 个点的方案数。)
环的指数型生成函数:(([x^i]G(x))表示 (i) 个点的环的方案数。)
所以我们需要将环和链组合起来,因为链的组合是有序的,而环的组合是无序的,所以最后的结果就是:
然后把函数带进去展开得到:
然后就可以直接计算答案了,时间复杂度 (O(nlog n)),空间复杂度 (O(n)),
Solution 1 的代码:
#include <cstdio>
int quick_power(int a,int b,int Mod){
int ans=1;
while(b){
if(b&1){
ans=1ll*ans*a%Mod;
}
b>>=1;
a=1ll*a*a%Mod;
}
return ans;
}
const int Maxn=10000;
const int Mod=998244353;
int f[Maxn+5][Maxn+5];
int n,k;
char a[Maxn+5],b[Maxn+5];
int s_1,s_2;
int frac[Maxn+5],inv_f[Maxn+5];
void init(){
frac[0]=1;
for(int i=1;i<=Maxn;i++){
frac[i]=1ll*frac[i-1]*i%Mod;
}
inv_f[Maxn]=quick_power(frac[Maxn],Mod-2,Mod);
for(int i=Maxn-1;i>=0;i--){
inv_f[i]=1ll*inv_f[i+1]*(i+1)%Mod;
}
}
int C(int n,int m){
return 1ll*frac[n]*inv_f[m]%Mod*inv_f[n-m]%Mod;
}
int main(){
init();
scanf("%s",a+1);
scanf("%s",b+1);
while(a[++n]!=' ');
n--;
for(int i=1;i<=n;i++){
if(a[i]=='1'){
k++;
}
if(a[i]=='1'&&b[i]=='1'){
s_1++;
}
else if(a[i]=='1'){
s_2++;
}
}
f[0][0]=1;
for(int i=0;i<=s_1;i++){
for(int j=1;j<=s_2;j++){
if(i==0&&j==0){
continue;
}
f[i][j]=(f[i][j]+1ll*f[i][j-1]*j%Mod*j)%Mod;
if(i>0){
f[i][j]=(f[i][j]+1ll*f[i-1][j]*i%Mod*j)%Mod;
}
}
}
int ans=0;
for(int i=0;i<=s_1;i++){
ans=(ans+1ll*f[s_1-i][s_2]*frac[i]%Mod*frac[i]%Mod*C(s_1,i)%Mod*C(k,i))%Mod;
}
printf("%d
",ans);
return 0;
}
Solution 2 的代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
int quick_power(int a,int b,int Mod){
int ans=1;
while(b){
if(b&1){
ans=1ll*ans*a%Mod;
}
b>>=1;
a=1ll*a*a%Mod;
}
return ans;
}
const int Maxn=40000;
const int G=3;
const int Mod=998244353;
int n,m,len;
char a[Maxn+5],b[Maxn+5];
void NTT(int *a,int flag,int n){
static int R[Maxn+5];
int len=1,L=0;
while(len<n){
len<<=1;
L++;
}
for(int i=0;i<len;i++){
R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
}
for(int i=0;i<len;i++){
if(i<R[i]){
swap(a[i],a[R[i]]);
}
}
for(int j=1;j<len;j<<=1){
int T=quick_power(G,(Mod-1)/(j<<1),Mod);
for(int k=0;k<len;k+=(j<<1)){
for(int l=0,t=1;l<j;l++,t=1ll*t*T%Mod){
int Nx=a[k+l],Ny=1ll*t*a[k+l+j]%Mod;
a[k+l]=(Nx+Ny)%Mod;
a[k+l+j]=(Nx-Ny+Mod)%Mod;
}
}
}
if(flag==-1){
reverse(a+1,a+len);
for(int i=0,t=quick_power(len,Mod-2,Mod);i<len;i++){
a[i]=1ll*a[i]*t%Mod;
}
}
}
void find_inv(int *a,int *b,int len){
static int c[Maxn+5],d[Maxn+5];
if(len==1){
b[0]=quick_power(a[0],Mod-2,Mod);
return;
}
find_inv(a,b,len>>1);
for(int i=0;i<len;i++){
c[i]=a[i];
d[i]=b[i];
}
for(int i=len;i<(len<<1);i++){
c[i]=d[i]=0;
}
NTT(c,1,len<<1);
NTT(d,1,len<<1);
for(int i=0;i<(len<<1);i++){
d[i]=1ll*d[i]*d[i]%Mod*c[i]%Mod;
}
NTT(d,-1,len<<1);
for(int i=0;i<len;i++){
b[i]=((b[i]<<1)%Mod-d[i]+Mod)%Mod;
}
for(int i=len;i<(len<<1);i++){
b[i]=0;
}
}
void find_dev(int *a,int len){
for(int i=0;i<len;i++){
a[i]=1ll*(i+1)*a[i+1]%Mod;
}
a[len-1]=0;
}
void find_dev_inv(int *a,int len){
for(int i=len-1;i>0;i--){
a[i]=1ll*quick_power(i,Mod-2,Mod)*a[i-1]%Mod;
}
a[0]=0;
}
void find_ln(int *a,int *b,int n){
static int c[Maxn+5];
for(int i=0;i<n;i++){
c[i]=a[i];
}
find_dev(c,n);
int len=1;
while(len<n){
len<<=1;
}
find_inv(a,b,len);
for(int i=n;i<len;i++){
b[i]=0;
}
for(int i=len;i<(len<<1);i++){
b[i]=c[i]=0;
}
NTT(b,1,len<<1);
NTT(c,1,len<<1);
for(int i=0;i<(len<<1);i++){
b[i]=1ll*b[i]*c[i]%Mod;
}
NTT(b,-1,len<<1);
find_dev_inv(b,len<<1);
for(int i=n;i<(len<<1);i++){
b[i]=0;
}
}
void find_exp(int *a,int *b,int len){
static int c[Maxn+5];
if(len==1){
b[0]=1;
return;
}
find_exp(a,b,len>>1);
find_ln(b,c,len);
c[0]=(a[0]+1-c[0]+Mod)%Mod;
for(int i=1;i<len;i++){
c[i]=(a[i]-c[i]+Mod)%Mod;
}
for(int i=len;i<(len<<1);i++){
b[i]=c[i]=0;
}
NTT(b,1,len<<1);
NTT(c,1,len<<1);
for(int i=0;i<(len<<1);i++){
b[i]=1ll*b[i]*c[i]%Mod;
}
NTT(b,-1,len<<1);
for(int i=len;i<(len<<1);i++){
b[i]=c[i]=0;
}
for(int i=0;i<len;i++){
printf("%d ",b[i]);
}
puts("");
}
int f[Maxn+5],g[Maxn+5];
int frac[Maxn+5],inv_f[Maxn+5];
void init(){
frac[0]=1;
for(int i=1;i<=Maxn;i++){
frac[i]=1ll*frac[i-1]*i%Mod;
}
inv_f[Maxn]=quick_power(frac[Maxn],Mod-2,Mod);
for(int i=Maxn-1;i>=0;i--){
inv_f[i]=1ll*inv_f[i+1]*(i+1)%Mod;
}
}
int main(){
init();
scanf("%s",a+1);
scanf("%s",b+1);
while(a[++len]!=' ');
for(int i=1;i<=len;i++){
if(a[i]=='1'){
n++;
if(b[i]=='0'){
m++;
}
}
}
for(int i=0;i<=n-m;i++){
f[i]=inv_f[i+1];
}
int len=1;
while(len<=n-m){
len<<=1;
}
find_ln(f,g,len);
memset(f,0,sizeof f);
for(int i=0;i<=n-m;i++){
f[i]=1ll*g[i]*m%Mod;
}
memset(g,0,sizeof g);
find_exp(f,g,len);
int ans=0;
for(int i=0;i<=n-m;i++){
f[i]=g[i];
ans=(ans+f[i])%Mod;
}
ans=1ll*ans*frac[m]%Mod*frac[n-m]%Mod*frac[n]%Mod;
printf("%d
",ans);
return 0;
}