题意
给定一个 (n imes m) 的网格,其中有 (k) 个关键点,求所有至少含有一个关键点的子矩形所含关键点数的方差。
(n, m le 10^9, k le 2000)
思路
令 (s_0, s_1, s_2) 为所有合法子矩形所含关键点数的 (0, 1, 2) 次和,容易推出方差与 (s_{0, 1, 2}) 的关系。
(s_n (n > 0)) 容易计算,考虑 (x^n) 的组合意义就是有序可重复地取出 (n) 个点,因此枚举 (n) 个点,计算包含它们的矩形数量即可,复杂度 (mathcal O(n^2))。
比较难算的是 (s_0)。考虑对矩形中的唯一的一个点进行计数,这里计数以 (x) 为第一关键字,(y) 为第二关键字排序的最后一个点。
先对点排序,下文用 ([x_l, x_r], [y_l, y_r]) 表示一个子矩形,按顺序枚举点 ((x, y))。
-
显然 (x_l in [1, x]) 都是合法的。
-
依次扫描位于 ((x, y)) 之后的点,这些点不能被子矩形覆盖,因此每个点会对 (y_l, y_r) 中的一个加以限制,每加入一个点计算器 ([x_i, x_{i+1})) 贡献即可。
复杂度 (mathcal O(n^2))。
代码
#include <cstdio>
#include <algorithm>
using namespace std;
#define File(s) freopen(s".in", "r", stdin), freopen(s".out", "w", stdout)
const int mod = 998244353;
inline int add(int x, int y) {return x+y>=mod ? x+y-mod : x+y;}
inline int sub(int x, int y) {return x-y<0 ? x-y+mod : x-y;}
inline int mul(int x, int y) {return 1LL * x * y % mod;}
inline void inc(int &x, int y=1) {x += y; if(x >= mod) x -= mod;}
inline void dec(int &x, int y=1) {x -= y; if(x < 0) x += mod;}
inline int power(int x, int y){
int res = 1;
for(; y; y>>=1, x = mul(x, x)) if(y & 1) res = mul(res, x);
return res;
}
inline int inv(int x){return power(x, mod - 2);}
template<class T> void upmax(T &x, T y){x = x>y ? x : y;}
template<class T> void upmin(T &x, T y){x = x<y ? x : y;}
const int N = 2005;
struct Pt{
int x, y;
}a[N];
int main(){
int n, m, k;
scanf("%d%d%d", &n, &m, &k);
for(int i=1; i<=k; i++)
scanf("%d%d", &a[i].x, &a[i].y);
sort(a + 1, a + 1 + k, [](Pt x, Pt y){
if(x.x == y.x) return x.y < y.y;
return x.x < y.x;
});
int s0 = 0, s1 = 0, s2 = 0;
for(int i=1; i<=k; i++)
inc(s1, mul(mul(a[i].x, n - a[i].x + 1), mul(a[i].y, m - a[i].y + 1)));
for(int i=1; i<=k; i++)
for(int j=1; j<=k; j++){
int xl = min(a[i].x, a[j].x), xr = max(a[i].x, a[j].x);
int yl = min(a[i].y, a[j].y), yr = max(a[i].y, a[j].y);
inc(s2, mul(mul(xl, n - xr + 1), mul(yl, m - yr + 1)));
}
a[k + 1].x = a[k].x;
for(int i=1; i<=k; i++){
int yl = 0, yr = m + 1;
int now = mul(a[i + 1].x - a[i].x, mul(a[i].y, m - a[i].y + 1));
for(int j=i+1; j<=k; j++){
if(a[j].y <= a[i].y) upmax(yl, a[j].y);
if(a[j].y >= a[i].y) upmin(yr, a[j].y);
if(a[j].x != a[j + 1].x)
inc(now, mul(mul(yr - a[i].y, a[i].y - yl), a[j + 1].x - a[j].x));
}
inc(now, mul(mul(yr - a[i].y, a[i].y - yl), n - a[k].x + 1));
inc(s0, mul(now, a[i].x));
}
int avg = mul(s1, inv(s0));
printf("%d
", add(mul(sub(s2, mul(2, mul(s1, avg))), inv(s0)), mul(avg, avg)));
return 0;
}