题意
求矩阵中上下对称且左右对称的正方形子矩阵的个数。
思路
快速对比矩阵是否上下对称或者左右对称可以考虑对二维矩阵哈希。
哈希之后通过枚举正方形中心点位置,对正方形边长进行二分。
在枚举正方形中心点时,需考虑正方形边长分别为奇偶的情况,即中心点为格子交接点还是格子中心点。
二维矩阵哈希值维护:
(hash[x][y] = hash[x][y-1] imes base1 + hash[x - 1][y] imes base2 + value[x][y])
二维矩阵哈希值求解:
(ans = hash[x][y] - hash[x - len_x][y] imes fac1[len_x] - hash[x][y -len_y] imes fac2[len_y] + hash[x - len_x][y - len_y] imes fac1[len_x] imes fac2[len_y])
其中 (fac1), (fac2) 分别为 (base1), (base2) 的 (len) 次方
代码
#include <cstdio>
#include <algorithm>
#include <iostream>
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
const int maxn = 2010;
const int base1 = 87;
const int base2 = 31;
int n, m;
int mp[maxn][maxn];
int lr[maxn][maxn];
int up[maxn][maxn];
ull hash_mp[maxn][maxn];
ull hash_lr[maxn][maxn];
ull hash_up[maxn][maxn];
ull fac1[maxn], fac2[maxn];
bool check(int x, int y, int len) {
if (x > n || y > m) return 0;
if (x < len || y < len) return 0;
ull ans1 = hash_mp[x][y] - hash_mp[x - len][y] * fac2[len] - hash_mp[x][y - len] * fac1[len]
+ hash_mp[x - len][y - len] * fac1[len] * fac2[len];
int cow_y = m - (y - len);
ull ans2 = hash_lr[x][cow_y] - hash_lr[x - len][cow_y] * fac2[len] - hash_lr[x][cow_y - len] * fac1[len]
+ hash_lr[x - len][cow_y - len] * fac1[len] * fac2[len];
if (ans1 != ans2) return 0;
int row_x = n - (x - len);
ull ans3 = hash_up[row_x][y] - hash_up[row_x - len][y] * fac2[len] - hash_up[row_x][y - len] * fac1[len]
+ hash_up[row_x - len][y - len] * fac1[len] * fac2[len];
if (ans1 != ans3) return 0;
return (ans1 == ans3 && ans2 == ans3);
}
void solve() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j)
scanf("%d", &mp[i][j]);
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
lr[i][j] = mp[i][m - j + 1];
up[i][j] = mp[n - i + 1][j];
}
}
fac1[0] = fac2[0] = 1;
for (int i = 1; i <= n; ++i) fac1[i] = fac1[i - 1] * base1;
for (int i = 1; i <= m; ++i) fac2[i] = fac2[i - 1] * base2;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
hash_mp[i][j] = hash_mp[i][j - 1] * base1 + mp[i][j];
hash_lr[i][j] = hash_lr[i][j - 1] * base1 + lr[i][j];
hash_up[i][j] = hash_up[i][j - 1] * base1 + up[i][j];
}
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
hash_mp[i][j] += hash_mp[i - 1][j] * base2;
hash_lr[i][j] += hash_lr[i - 1][j] * base2;
hash_up[i][j] += hash_up[i - 1][j] * base2;
}
}
ll ans = 0;
int R = max(n, m);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
int l = 1, r = R;
while (l <= r) {
int mid = l + r >> 1;
if (check(i + mid, j + mid, mid << 1 | 1)) {
l = mid + 1;
} else {
r = mid - 1;
}
}
// cout << "~ " << l << " " << r << endl;
ans += r;
}
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
int l = 1, r = R;
while (l <= r) {
int mid = l + r >> 1;
if (check(i + mid, j + mid, mid << 1)) {
l = mid + 1;
} else {
r = mid - 1;
}
}
// cout << "~~ " << i << " " << j << " " << l << " " << r << endl;
ans += r;
}
}
printf("%lld
", ans + n * m);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--)
solve();
return 0;
}