给定 (n imes m) 的自然数矩阵 (a),求有多少个 ((r_1,r_2,c_1,c_2)) 满足 (1le r_1le r_2le n-2),(1le c_1le c_2le m-2),(forall iin[r_1,r_2],jin[c_1,c_2]),(a_{i,j}<min(a_{i,c_1-1},a_{i,c_2+1},a_{r_1-1,j},a_{r_2+1,j}))。
(n,mle 2500),(a_{i,j}le 7cdot 10^6)。
首先可以发现,充要条件就是 (forall iin[r_1,r_2]),(a_{i,c_1-1}) 右边第一个不小于它的值是 (a_{i,c_2+1}) 或 (a_{i,c_2+1}) 左边第一个不小于它的值是 (a_{i,c_1-1}),列同理。
对于每一行,满足条件的 ((c_1,c_2)) 可以用单调栈算出来,并且只有至多 (2m) 个。列同理。
考虑如何计算,先对每行算一遍,把所有 ((c_1,c_2)) 对应的行标号记录下来,然后从小到大枚举右边界 (c_2),把这一列算一遍,维护所有 ((r_1,r_2)) 对应的当前最右连续区间 ([lb_{r_1,r_2},rb_{r_1,r_2}]),然后枚举左边界 (c_1),遍历 ((c_1,c_2)) 对应的行标号连续段。
现在已经去掉了行限制,至于列限制就直接对其中一列做一遍,可能满足条件的 ((r_1,r_2)) 只有至多 (2n) 个,判断就看 ([lb_{r_1,r_2},rb_{r_1,r_2}]) 是否包含 ([c_1,c_2])。
时间复杂度 (O(nm))。
#include<bits/stdc++.h>
#define PB emplace_back
#define fi first
#define se second
using namespace std;
typedef pair<int, int> pii;
const int N = 2502, K = 5e7;
char buf[K], *in = buf;
int read(){
int x = 0;
for(;!isdigit(*in);++ in);
for(;isdigit(*in);++ in) x = x * 10 + *in - '0';
return x;
}
int n, m, a[N][N], lb[N][N], rb[N][N], stk[N], tp, ans;
vector<int> ok[N][N];
vector<pii> res;
void work(int *b, int l){
res.resize(tp = 0);
for(int i = 1;i <= l;++ i){
while(tp && b[i] > b[stk[tp]]){
if(i > stk[tp]+1) res.PB(stk[tp]+1, i-1);
-- tp;
}
if(tp){
if(i > stk[tp]+1) res.PB(stk[tp]+1, i-1);
if(b[i] == b[stk[tp]]) -- tp;
}
stk[++tp] = i;
}
}
void calc(int l, int r, int u, int d){
for(int i = u-1;i <= d+1;++ i)
a[0][i-u+2] = a[i][l];
work(*a, d-u+3);
for(pii p : res){
int L = p.fi+u-2, R = p.se+u-2;
ans += (lb[L][R] <= l && r <= rb[L][R]);
}
}
int main(){
fread(buf, 1, K, stdin);
n = read(); m = read();
for(int i = 1;i <= n;++ i)
for(int j = 1;j <= m;++ j)
a[i][j] = read();
for(int i = 2;i < n;++ i){
work(a[i], m);
for(pii p : res) ok[p.fi][p.se].PB(i);
}
for(int r = 2;r < m;++ r){
for(int i = 1;i <= n;++ i)
a[0][i] = a[i][r];
work(*a, n);
for(pii p : res){
if(rb[p.fi][p.se] < r-1) lb[p.fi][p.se] = r;
rb[p.fi][p.se] = r;
}
for(int l = 2;l <= r;++ l){
int len = ok[l][r].size();
if(!len) continue;
int lst = ok[l][r][0];
for(int i = 1;i < len;++ i)
if(ok[l][r][i] > ok[l][r][i-1]+1){
calc(l, r, lst, ok[l][r][i-1]);
lst = ok[l][r][i];
}
calc(l, r, lst, ok[l][r][len-1]);
}
}
printf("%d
", ans);
}