[BZOJ2738]矩阵乘法
试题描述
给你一个 (N imes N) 的矩阵,不用算矩阵乘法,但是每次询问一个子矩形的第 (K) 小数。
输入
第一行两个数 (N,Q),表示矩阵大小和询问组数;
接下来 (N) 行 (N) 列一共 (N imes N) 个数,表示这个矩阵;
再接下来 (Q) 行每行 (5) 个数描述一个询问:(x_1,y_1,x_2,y_2,k) 表示找到以 ((x_1,y_1)) 为左上角、以 ((x_2,y_2)) 为右下角的子矩形中的第 (K) 小数。
输出
对于每组询问输出第 (K) 小的数。
输入示例
2 2
2 1
3 4
1 2 1 2 1
1 1 2 2 3
输出示例
1
3
数据规模及约定
矩阵中数字是 (10^9) 以内的非负整数;
(20 exttt{%}) 的数据:(N le 100,Q le 1000);
(40 exttt{%}) 的数据:(N le 300,Q le 10000);
(60 exttt{%}) 的数据:(N le 400,Q le 30000);
(100 exttt{%}) 的数据:(N le 500,Q le 60000)。
题解
所谓的整体二分,其实就是一种处理离线询问的方法。
令 (solve(l, r, S)) 表示处理答案区间为 ([l, r]) 的询问集合为 (S) 的部分。那么令 (m = lfloor frac{l +r}{2} floor),对于 (S) 中的询问,我们先将 ([l, m]) 中的数加入二维树状数组中,然后查询每个 (S) 中的询问,如果比询问中的 (k) 大,则该询问的答案在 ([l, m]) 中,否则在 ([m + 1, r]) 中。
注意整体二分的时间复杂度分析,很容易不小心写成 (O(n^2)) 的。(本题 (O(n^2log_2^3n + qlog_2^3n)))
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
#define rep(i, s, t) for(int i = (s); i <= (t); i++)
#define dwn(i, s, t) for(int i = (s); i >= (t); i--)
int read() {
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
return x * f;
}
#define maxn 510
#define maxq 60010
#define pii pair <int, int>
#define x first
#define y second
#define mp(x, y) make_pair(x, y)
int n, q, A[maxn][maxn], num[maxn*maxn], ToT, head[maxn*maxn], nxt[maxn*maxn];
pii pos[maxn*maxn];
struct Que {
int x1, y1, x2, y2, k, id;
bool type;
Que() {}
Que(int _1, int _2, int _3, int _4, int _k, int _id): x1(_1), y1(_2), x2(_3), y2(_4), k(_k), id(_id) {}
} qs[maxq], tqs[maxq];
int C[maxn][maxn];
void Add(int x, int y, int v) {
for(; x <= n; x += x & -x)
for(int Y = y; Y <= n; Y += Y & -Y) C[x][Y] += v;
return ;
}
int que(int x, int y) {
int sum = 0;
for(; x; x -= x & -x)
for(int Y = y; Y; Y -= Y & -Y) sum += C[x][Y];
return sum;
}
int Query(int x1, int y1, int x2, int y2) {
return que(x2, y2) - que(x1 - 1, y2) - que(x2, y1 - 1) + que(x1 - 1, y1 - 1);
}
int Ans[maxq];
void solve(int l, int r, int ql, int qr) {
if(ql > qr) return ;
if(l == r) {
rep(i, ql, qr) Ans[qs[i].id] = num[l];
return ;
}
int mid = l + r >> 1;
// printf("[%d, %d] %d [%d, %d]
", l, r, mid, ql, qr);
rep(v, l, mid)
for(int i = head[v]; i; i = nxt[i]) Add(pos[i].x, pos[i].y, 1);
int cnt = 0, lim;
rep(i, ql, qr) {
// printf("(%d, %d)(%d, %d)
", qs[i].x1, qs[i].y1, qs[i].x2, qs[i].y2);
int tmp = Query(qs[i].x1, qs[i].y1, qs[i].x2, qs[i].y2);
if(qs[i].k > tmp) qs[i].k -= tmp, qs[i].type = 1;
else qs[i].type = 0, tqs[++cnt] = qs[i];
}
rep(v, l, mid)
for(int i = head[v]; i; i = nxt[i]) Add(pos[i].x, pos[i].y, -1);
lim = cnt;
rep(i, ql, qr) if(qs[i].type) tqs[++cnt] = qs[i];
rep(i, ql, qr) qs[i] = tqs[i-ql+1];
solve(l, mid, ql, ql + lim - 1); solve(mid + 1, r, ql + lim, qr);
return ;
}
int main() {
n = read(); q = read();
int cntn = 0;
rep(i, 1, n) rep(j, 1, n) num[++cntn] = A[i][j] = read();
rep(i, 1, q) {
int x1 = read(), y1 = read(), x2 = read(), y2 = read(), K = read();
qs[i] = Que(x1, y1, x2, y2, K, i);
}
sort(num + 1, num + cntn + 1);
rep(i, 1, n) rep(j, 1, n)
A[i][j] = lower_bound(num + 1, num + cntn + 1, A[i][j]) - num,
pos[++ToT] = mp(i, j), nxt[ToT] = head[A[i][j]], head[A[i][j]] = ToT;
solve(1, cntn, 1, q);
rep(i, 1, q) printf("%d
", Ans[i]);
return 0;
}