题目
点这里看题目。
分析
题目明显是要求我们求方案数。
显然这道题没有办法直接做。
考虑转化一下题目条件。可以发现我们应该让 (A) 中多余的 1 换到 (A) 中缺少 1 的位置去。为了使描述更加清晰,我们这样定义:
-
公共点((P)):满足 (A_i=1land B_i=1) 的 (i)。可以发现无论怎么交换,最终 (P) 上总是 1 。
-
起点((S)):满足 (A_i=0land B_i=1) 的 (i)。我们需要将 (S) 上的 0 转移走。
-
终点((E)):满足 (A_i=1land B_i=0) 的 (i)。我们需要将 (S) 上的 0 转移到 (E) 上来。
可以发现,最终可以使得 (A=B) 的操作序列必然满足:
连接边 ((a_i,b_i)),则图的形态应该是一大堆由 (S) 和 (E) 作为端点,(P) 作为中间点的链。
注意这里的链应该是“有向”的,即我们不能倒着操作一条链。
好的,这样已经清晰多了。我们考虑写出状态和转移:
(f(i,j)):使用了 (i) 个 (P) ,组成了 (j) 条链的真实序列方案数。
不难考虑转移:
-
加入一个新的 (P) 。首先我们应该选取它所在的链((j)),钦定它在末尾,再考虑它的标号((i))。此时的贡献就是 (f(i-1,j) imes i imes j)。
-
加入一个新的链。我们继续钦定它放在末尾,并且考虑 (S) 和 (E) 的标号((j^2))。此时的贡献就是 (f(i,j-1) imes j^2)。
需要注意的是,每次转移必然会导致真实序列(也就是 (a) 和 (b))长度加一。我们同样钦定每次新增后放在末尾。
真实情况下,一条链可能会有许多种对应的真实序列,而同一条链的不同的真实序列是由不同转移顺序来区分的。
于是就有转移:
考虑统计答案。注意我们不一定要所有的 (P) 都在 (S-E) 链上。因此我们需要枚举一下不在链上的 (P) 的数量。
设 (P) 点有 (s) 个,(S) 和 (E) 各有 (t) 个。
因此有答案为:
其中 (inom{s}{i} imes (i!)^2) 是在计算不在链上的 (P) 的带标号形态, (inom{s+t}{i}) 是在合并两个序列。
最后我们就得到了时间为 (O(n^2)) 的算法。
本题一些有价值的点:
-
考虑序列的相同与不同,于是就有了 (P,S,E) 三种点。
-
将交换看成边,发现链的性质。同时这也令人想起 树上的数 。
-
考虑序列的 DP 的时候,要么考虑标号,要么考虑位置,同时考虑会算重。
代码
#include <cstdio>
const int mod = 998244353;
const int MAXN = 10005;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
int f[MAXN][MAXN];
int fac[MAXN], ifac[MAXN];
char A[MAXN], B[MAXN];
int N;
int qkpow( int base, int indx )
{
int ret = 1;
while( indx )
{
if( indx & 1 ) ret = 1ll * ret * base % mod;
base = 1ll * base * base % mod, indx >>= 1;
}
return ret;
}
void init( const int siz )
{
fac[0] = 1;
for( int i = 1 ; i <= siz ; i ++ ) fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[siz] = qkpow( fac[siz], mod - 2 );
for( int i = siz - 1 ; ~ i ; i -- ) ifac[i] = 1ll * ifac[i + 1] * ( i + 1 ) % mod;
}
int C( const int n, const int m )
{
if( n < m || n < 0 || m < 0 ) return 0;
return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
int main()
{
int S = 0, T = 0;
scanf( "%s%s", A + 1, B + 1 );
for( N = 1 ; A[N] ; N ++ )
{
int a = A[N] - '0', b = B[N] - '0';
if( a && b ) S ++;
if( a && ! b ) T ++;
}
init( N );
f[0][0] = 1;
for( int i = 0 ; i <= S ; i ++ )
for( int j = 0 ; j <= T ; j ++ )
{
if( i ) add( f[i][j], 1ll * f[i - 1][j] * i % mod * j % mod );
if( j ) add( f[i][j], 1ll * f[i][j - 1] * j % mod * j % mod );
}
int ans = 0;
for( int i = 0 ; i <= S ; i ++ )
add( ans, 1ll * C( S + T, i ) * C( S, i ) % mod * fac[i] % mod * fac[i] % mod * f[S - i][T] % mod );
write( ans ), putchar( '
' );
return 0;
}