题目
点这里看题目。
分析
一个暴力的 (O(nm)) 的 DP 不难看出:
(f(i,j)):当前有 (i) 个 YES ,(j) 个 NO 的时候的期望的最大猜对数。
转移的时候,我们肯定选择猜对概率大的那个作为猜测的答案,也就是有;而当前问题有 (frac{i}{i+j}) 的概率是 YES ,剩下的概率就是 NO ,也就是说:
[f(i,j)=frac{i}{i+j}f(i-1,j)+frac{j}{i+j}f(i,j-1)+frac{max{i,j}}{i+j}
]
然后它是 (O(nm)) 的,对。
考虑优化这个过程。以下假设 (nge m) 。
显然我们是在进行从 ((n,m)) 到 ((0,0)) 的游走。每当我们走过 (i=j) 这条对角线的时候, (max{i,j}) 的值就会变动。
游走的时候,如果不经过 (i=j) 这条线,我们只会才 YES ,因而只会得到 (n) 的贡献。
如果经过了 (i=j) 这条线,设路径为 (P) ,将它对角线之上的部分翻折到下面来,得到新路径 (P') 。此时 (P') 的贡献也是 (n) ;而 (P) 和 (P') 经过的概率还是一样的。因而经过了 (i=j) 的这条线之后,我们仍然至少可以得到 (n) 的期望。
注意到这里的措辞是 " 至少 " 。事实上, (i = j) 的时候,我们有几率得到正确的答案,因而在 (i=j) 的时候我们会得到 (frac{1}{2}) 的额外期望。
基于期望的线性性,我们可以对每个对角线上的点 ((k, k)) ,计算它的额外期望对答案的贡献,我们只需要考虑经过它的概率:
[frac{inom{m-k+n-k}{m-k} imesinom{2k}{k}}{inom{n+m}{m}}
]
其中分子是经过 ((k,k)) 的路径条数,分母是总路径条数。
因而经过对角线的额外期望就是:
[frac12sum_{k=1}^mfrac{inom{m-k+n-k}{m-k}inom{2k}{k}}{inom{n+m}{m}}
]
计算的时间是 (O(n)) 的。
代码
#include <cstdio>
const int mod = 998244353;
const int MAXN = 2e6 + 5;
template<typename _T>
void read( _T &x )
{
x = 0;char s = getchar();int f = 1;
while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
x *= f;
}
template<typename _T>
void write( _T x )
{
if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
if( 9 < x ){ write( x / 10 ); }
putchar( x % 10 + '0' );
}
template<typename _T>
void swapp( _T &x, _T &y )
{
_T t = x; x = y, y = t;
}
int fac[MAXN], ifac[MAXN];
int N, M;
int qkpow( int, int );
int inv( const int a ) { return qkpow( a, mod - 2 ); }
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }
int C( const int n, const int m ) { return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod; }
int invC( const int n, const int m ) { return 1ll * ifac[n] * fac[m] % mod * fac[n - m] % mod; }
int qkpow( int base, int indx )
{
int ret = 1;
while( indx )
{
if( indx & 1 ) ret = 1ll * ret * base % mod;
base = 1ll * base * base % mod, indx >>= 1;
}
return ret;
}
void init( const int siz )
{
fac[0] = ifac[0] = 1;
for( int i = 1 ; i <= siz ; i ++ ) fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[siz] = inv( fac[siz] );
for( int i = siz - 1 ; i ; i -- ) ifac[i] = 1ll * ifac[i + 1] * ( i + 1 ) % mod;
}
int main()
{
read( N ), read( M );
if( N < M ) swapp( N, M );
init( N + M );
int ans = 0, tmp = invC( N + M, N );
for( int k = 1 ; k <= M ; k ++ )
add( ans, 499122177ll * C( M - k + N - k, M - k ) % mod * C( 2 * k, k ) % mod * tmp % mod );
add( ans, N ); write( ans ), putchar( '
' );
return 0;
}