Problem
Description
你有一个 (N) 行、(M) 列的、每个格子都填写着 0 的表格。你进行了下面的操作:
- 对于每一行 (i) ,选定自然数 (r_i) ((0leq r_ileq M)),将这一行最左边的 (r_i) 个格子中的数 (+1).
- 对于每一列 (i) ,选定自然数 (c_i) ((0leq c_ileq N)),将这一列最上边的 (c_i) 个格子中的数 (+1).
这样,根据你选定的 (r_1,r_2,ldots,r_N,c_1,c_2,ldots,c_M) ,你就得到了一个每个格子要么是 0,要么是 1,要么是 2 的一个最终的表格。问本质不同的最终表格有多少种。两个表格本质不同当且进当它们有一个对应格子中的数不同。
Range
(1leq N,M leq 5cdot 10^5)
Algorithm
容斥原理
Mentality
我们应该直接考虑重复的情况是怎么样的。
对于一对行和列,我们先假设其他行列的操作已经完成了,只需要考虑当前行列有多少种操作令结果不同。
然后缜密思索,我们发现只会有两个操作的结果是相同的。
假设我们正在考虑行 (i) 与列 (j) ,那么不难发现,只有当 (r_i=j,c_j=i-1) 和 (r_i=j-1,c_i=i) 这两种情况时,它们的结果会相同,对于其他任意情况而言,结果唯一。
则我们只需要枚举有几对行列选择了这两种会重复的状态的前一种,剩下的随便填,然后利用容斥原理计算答案即可。
对于枚举 (k) ,则有:
[f(k)=C^N_k*C^M_k*k!*(M+1)^{N-k}*(N+1)^{M-k}
]
则:
[ans=sum_{k=0}^{min(N,M)}(-1)^kf(k)
]
Code
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <vector>
using namespace std;
long long read() {
long long x = 0, w = 1;
char ch = getchar();
while (!isdigit(ch)) w = ch == '-' ? -1 : 1, ch = getchar();
while (isdigit(ch)) {
x = (x << 3) + (x << 1) + ch - '0';
ch = getchar();
}
return x * w;
}
const int Max_n = 5e5 + 5, mod = 998244353;
int n, m, ans;
int fac[Max_n], ifac[Max_n];
int f[Max_n];
int ksm(int a, int b) {
int res = 1;
for (; b; b >>= 1, a = 1ll * a * a % mod)
if (b & 1) res = 1ll * res * a % mod;
return res;
}
int C(int n, int m) { return 1ll * fac[n] * ifac[m] % mod * ifac[n - m] % mod; }
int main() {
#ifndef ONLINE_JUDGE
freopen("F.in", "r", stdin);
freopen("F.out", "w", stdout);
#endif
n = read(), m = read();
if (n > m) swap(n, m);
fac[0] = ifac[0] = 1;
for (int i = 1; i <= m; i++) fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[m] = ksm(fac[m], mod - 2);
for (int i = m - 1; i; i--) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
for (int i = 0; i <= n; i++) {
f[i] = 1ll * C(n, i) * C(m, i) % mod * fac[i] % mod;
f[i] = 1ll * f[i] * ksm(m + 1, n - i) % mod * ksm(n + 1, m - i) % mod;
}
for (int i = 0; i <= n; i++)
ans = ((ans + ksm(-1, i & 1) * f[i]) % mod + mod) % mod;
cout << ans;
}