Solution
首先枚举 x, y 的高度 i,分三种情况讨论:
1. x, y 一个在左边,一个在右边:
如图,两个方框分别代表 x, y,红线把数列分成 4 段,考虑对于每一段分别计算答案,然后相乘。对于每一段,其长度已经固定,取数范围也已经固定。例如,对于 l1,其长度为 x - 1,取数范围为 ([1,i])。组合数时使用可重复组合,计算公式为 (H_n^r = C_{n+r-1}^r)。
2. x, y 都在左边 仿照上例。
3. x, y 都在右边 仿照上例。
计算组合数时需要用到乘法逆元,可以预处理出阶乘的逆元。注意经常取模,稍微不注意就会爆 long long。
Code
#include <iostream>
#include <cstdio>
#include <cstring>
#define LL long long
using namespace std;
const int N = 6666666;
const LL mod = 998244353;
LL m, n, x, y, Ans = 0, inv[N], fac[N];
LL poww(LL b, LL p)
{
LL res = 1;
while(p)
{
if(p & 1) res = (res * b) % mod;
b = (b * b) % mod;
p >>= 1;
}
return res;
}
LL C(LL x, LL y) { return (fac[x] * ((inv[y] * inv[x - y]) % mod)) % mod; }
int main()
{
scanf("%lld%lld%lld%lld", &m, &n, &x, &y);
fac[0] = 1;
for(int i = 1; i <= n * 2 + m; i++)
fac[i] = (fac[i - 1] * i) % mod;
inv[n * 2 + m] = poww(fac[n * 2 + m], mod - 2);
inv[0] = 1;
for(int i = n * 2 + m - 1; i > 0; i--)
inv[i] = (inv[i + 1] * (i + 1) % mod + mod) % mod;
if(x <= n && y >= n)
{
for(int i = 1; i <= m; i++)
{
LL l1 = x - 1, l2 = n * 2 - y, l3 = n - x, l4 = y - n - 1, now = 1;
LL num1 = i, num2 = i, num3 = m - i + 1, num4 = m - i + 1;
now = (now * C(num1 + l1 - 1, l1) + mod) % mod;
now = (now * C(num2 + l2 - 1, l2) + mod) % mod;
now = (now * C(num3 + l3 - 1, l3) + mod) % mod;
now = (now * C(num4 + l4 - 1, l4) + mod) % mod;
Ans = (Ans + now + mod) % mod;
}
}
else if(y < n)
{
for(int i = 1; i <= m; i++)
{
LL l1 = x - 1, l2 = y - x - 1, l3 = n - y, l4 = n;
LL num1 = i, num2 = 1, num3 = m - i + 1, num4 = m, now = 1;
now = (now * C(num1 + l1 - 1, l1) + mod) % mod;
now = (now * C(num2 + l2 - 1, l2) + mod) % mod;
now = (now * C(num3 + l3 - 1, l3) + mod) % mod;
now = (now * C(num4 + l4 - 1, l4) + mod) % mod;
Ans = (Ans + now + mod) % mod;
}
}
else
{
for(int i = 1; i <= m; i++)
{
LL l1 = n, l2 = x - n - 1, l3 = y - x - 1, l4 = n * 2 - y;
LL num1 = m, num2 = m - i + 1, num3 = 1, num4 = i, now = 1;
now = (now * C(num1 + l1 - 1, l1) + mod) % mod;
now = (now * C(num2 + l2 - 1, l2) + mod) % mod;
now = (now * C(num3 + l3 - 1, l3) + mod) % mod;
now = (now * C(num4 + l4 - 1, l4) + mod) % mod;
Ans = (Ans + now + mod) % mod;
}
}
printf("%lld", Ans);
return 0;
}