PS:因为本题要用到线段树,所以数组都是从(1)开始存的。
题意
给定一组有(n)个数的数组(a),如果(a_i = i),那么你可以将(a_i)消除,然后(a_i)后面的数下标都(-1)。一共有(q)次询问,每次询问会给定一组((x,y)),问如果不能动前(x)个数和后(y)个数,最多可以消除多少个数?
思路
PS:没有特殊说明,下文提到的(i)都是指的数组下标。
首先,对于每一次询问,我们都可以得到我们能操作的区间——(l = x + 1, r = n - y),那么我们能操作的区间就是([l, r])。
暴力
很容易想到暴力的思路:
我们从(l)开始枚举,如果遇到(a_i = i),那么(ans := ans + 1);如果遇到(a_i > i),那么这个数则不可能被消除;如果遇到(a_i < i),那么,如果此时的(ans geqslant i - a_i),那么说明我们可以在消除前(ans)个数时,可以消除(a_i),否则(a_i)不可能被消除。
这样,我们的时间复杂度是(O(nq))的。再看一下数据范围:(1 leq n,q leq 3 imes 10^5),稳TLE。
改进
本题有一个很麻烦的地方就在于,它会限制前(x)个数不能动。如果前(x)个数不动,后面就可能有一些元素,从(1)开始消除,它们是能消除的,但是从(x + 1)开始,它们就不能消除了。
也就是说,对于每一个可能被消除的元素(a_i),它都存在一个边界(j),当(x leqslant j)时,(a_i)是能被消除的,当(x > j)时,它就消除不了了。对于每个(i),我们称这个(j)为$ l[i](。那么,换句话说,当)l leqslant l[i]$时,这个元素能被消除;否则不能被消除。
那么,这个时候,显然答案就变成了,对于每一次询问的(l, r),我们看([l,r])中有多少个(l[i])落在区间([l,r])上。
那么,对于这个查询,我们就可以考虑维护一个线段树,枚举每个(i),它的(l[i])位置上的(val +1)即可,然后查询就直接用线段树区间查询([l,r])就是答案。
PS:这里会有个问题,就是如果我先一次性把从(1)到(n)的所有(l[i])都加进线段树里,那么我查询([l,r])的时候,可能会有(i > r),但是(l leqslant l[i] leqslant r)的情况,此时会把下标大于(r)的数也算进去,但是显然我们的答案不能算进去这一部分。
要解决这个问题,我们就一次性把所有查询读进来离线处理。我们把所有的查询按照(r)从小到大排序,然后处理到第(i)个的时候,就把所有(i leqslant query[j].r)的(a_i)放进线段树里。这样就能避免(i > r)那一部分元素的影响了。
求(l[i])
这是最后一个麻烦事。我们考虑(l[i])的含义:当(l leqslant l[i])时,这个元素能被消除;否则不能被消除。
所以,我们有:
- 当(a_i = i)时,显然,(l[i] = i);
- 当(a_i > i)时,显然,无论如何,我们都无法消除(a_i),此时我们给(l[i])赋值为(-1);
- 当(a_i < i)时,(a_i)有可能会被消除,也有可能无法被消除:
- 当(i)前面所有能消除的数都被消除以后,(a_i)仍然小于(i),那么此时,(a_i)无法被消除,我们给(l[i])赋值为(-1);
- 当(i)前面的(i - a_i)个数被消除以后,(a_i)就能被消除了。此时,我们消除最靠近(i)的(i - a_i)个数显然是最优的。
前面的都很简单,我们现在只考虑第三点:如何考虑消除这(i - a_i)个数?
首先,我们是从(1)到(n)枚举的。这保证了线段树上所有(+1)的点,都是(i)前面的数造成的;
然后,对于每个(l[i]),我们在线段数上在(l[i])这个位置(+1),这保证了一个能被消除的点的贡献一定是(1);
所以,我们要考虑消除最靠近(i)的(i - a_i)个点,就可以变成,我们找能被消除的第(num - (i - a_i) + 1)个点。这个点,就是(l[i])。
所以,对于第三点,找(l[i])就变成了找一个(j),满足(sum(1,j) = num - (i - a_i) + 1)。(也可以理解成找一个(j),满足(sum(j, i - 1) = i - a_i))。其中(num)是从(1)到(i - 1)能被消除的数的总数。这个操作可以在线段树上(O(log n))完成,代码如下:
int findloc(int rt, int l, int r, const int &k) {
if (l == r)
return l;
int mid = (l + r) >> 1;
if (seg[rt << 1] >= k)
return findloc(rt << 1, l, mid, k);
else
return findloc(rt << 1 | 1, mid + 1, r, k - seg[rt << 1]);
}
其中的(k)就是上面说到的(j)。
代码
为了方便处理,我在代码中一开始就直接把所有的(a_i)替换成了(i - a_i)。线段树可以重复使用,不过重复使用前不要忘记build
一波初始化。
#include <cstdio>
#include <cstring>
#include <algorithm>
const int maxn = 3e5 + 5;
using namespace std;
struct Triple {
int l, r, id, ans;
}t[maxn];
int a[maxn], n, q, f[maxn], l[maxn];
int seg[maxn << 2];
inline int cmp(const Triple &a, const Triple &b) { // 按照r排序
return a.r < b.r;
}
inline int cmp2(const Triple &a, const Triple &b) { // 按照id排序,处理完了最后输出
return a.id < b.id;
}
/***********线段树***********/
inline void pushup(int rt) {
seg[rt] = seg[rt << 1] + seg[rt << 1 | 1];
}
// 用于初始化
void build(int rt, int l, int r) {
if (l == r) {
seg[rt] = 0;
return ;
}
int mid = (l + r) >> 1;
build(rt << 1, l, mid);
build(rt << 1 | 1, mid + 1, r);
pushup(rt);
}
// 单点修改
void update(int rt, int l, int r, const int &loc, const int &val) {
if (l == r) {
seg[rt] += val;
return ;
}
int mid = (l + r) >> 1;
if (loc <= mid) update(rt << 1, l, mid, loc, val);
else if (loc > mid) update(rt << 1 | 1, mid + 1, r, loc, val);
pushup(rt);
}
// 查询sum[1, r] == k的r的位置
int findloc(int rt, int l, int r, const int &k) {
if (l == r)
return l;
int mid = (l + r) >> 1;
if (seg[rt << 1] >= k)
return findloc(rt << 1, l, mid, k);
else
return findloc(rt << 1 | 1, mid + 1, r, k - seg[rt << 1]);
}
// 区间查询求和
int query(int rt, int l, int r, const int &L, const int &R) {
if (L <= l && r <= R) {
return seg[rt];
}
int mid = (l + r) >> 1;
int res = 0;
if (L <= mid) res += query(rt << 1, l, mid, L, R);
if (R > mid) res += query(rt << 1 | 1, mid + 1, r, L, R);
pushup(rt);
return res;
}
/***********线段树***********/
int main() {
scanf("%d%d", &n, &q);
for (int i = 1; i <= n; i++)
scanf("%d", a + i);
for (int i = 1; i <= n; i++)
a[i] = i - a[i];
build(1, 1, n);
int num = 0;
for (int i = 1; i <= n; i++) {
if (a[i] < 0) l[i] = -1;
else if (a[i] == 0) {
l[i] = i;
num++;
update(1, 1, n, i, 1);
}
else {
int tmp = num - a[i] + 1;
if (tmp <= 0) l[i] = -1;
else {
l[i] = findloc(1, 1, n, tmp);
update(1, 1, n, l[i], 1);
num++;
}
}
}
for (int i = 0; i < q; i++) {
scanf("%d%d", &t[i].l, &t[i].r);
t[i].l++;
t[i].r = n - t[i].r;
t[i].id = i;
}
sort(t, t + q, cmp);
int cur = 1;
build(1, 1, n);
for (int i = 0; i < q; i++) {
while (cur <= t[i].r) {
if (l[cur] > 0)
update(1, 1, n, l[cur], 1);
cur++;
}
t[i].ans = query(1, 1, n, t[i].l, t[i].r);
}
sort(t, t + q, cmp2);
for (int i = 0; i < q; i++)
printf("%d
", t[i].ans);
return 0;
}