题目大意
构造一棵([1,n])的线段树,有(q)个询问([x,y]),每次查询([x,y])的所有子区间在线段树上经过的点数之和。
(n,q leq 500000)
Solution
一开始方向错了。。。。
显然线段树上只有和([x,y])有交集的区间才会产生贡献。
设该点代表区间为([l,r]):
- 若([l,r])包含([x,y]),则([x,y])的所有子区间都会经过该点。
- 若([l,r])与([x,y])相交,那么只有和([l,r])有交集的子区间有贡献,于是用([x,y])所有的子区间减去和([l,r])没有交的子区间。
- 若([x,y])包含([l,r]),这时([x,y])的子区间必须和([l,r])有交集且不能包含([l,r])的父区间,同样做一下减法就行了。
现在问题是,第一、二种情况的([l,r])都是(O(logn))个的,可以暴力做。第三种情况,若([x,y])包含了([l,r]),([x,y])肯定也包含了([l,r])的所有子区间,不能暴力下去处理答案。如果我们把([l,r])对([x,y])的贡献写出来,是一个和(x,y)有关的二次多项式,于是我们可以维护每个区间各项的系数,统计一下子树内系数之和就行了。
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
#define lson rt << 1
#define rson rt << 1 | 1
using namespace std;
typedef long long ll;
const int N = 500007;
int n, q, opt;
ll l, r, a, b, lastans;
ll C2(ll n) {
return n * (n + 1) / 2;
}
ll sum[N << 2][3];
void pre(int rt, int l, int r, int fl, int fr) {
if (l != 1 || r != n) {
sum[rt][0] = 2 * l - 2 * fr;
sum[rt][1] = 2 * r - 2 * fl;
sum[rt][2] = -l - 1ll * l * l + r - 1ll * r * r + 2ll * fl * fr - 2 * fl + 2 * fr;
}
if (l == r) return;
int mid = l + r >> 1;
pre(lson, l, mid, l, r), pre(rson, mid + 1, r, l, r);
for (int i = 0; i < 3; ++i) sum[rt][i] += sum[lson][i] + sum[rson][i];
}
void go(int rt, int l, int r, int ql, int qr) {
if (l <= ql && qr <= r) lastans += C2(qr - ql + 1);
if (l < ql && r >= ql && r < qr) lastans += C2(qr - ql + 1) - C2(qr - r);
if (l > ql && l <= qr && r > qr) lastans += C2(qr - ql + 1) - C2(l - ql);
if (ql <= l && r <= qr) {
if (ql != l || r != qr) lastans += C2(qr - ql + 1) - C2(l - ql) - C2(qr - r);
if (l != r) lastans += ((sum[lson][0] + sum[rson][0]) * ql + (sum[lson][1] + sum[rson][1]) * qr + sum[lson][2] + sum[rson][2]) / 2;
return;
}
int mid = l + r >> 1;
if (ql <= mid) go(lson, l, mid, ql, qr);
if (mid + 1 <= qr) go(rson, mid + 1, r, ql, qr);
}
int main() {
freopen("ran.in", "r", stdin);
freopen("ran.out", "w", stdout);
scanf("%d%d%d", &n, &q, &opt);
pre(1, 1, n, 1, n);
while (q--) {
scanf("%lld%lld", &l, &r);
a = (l ^ (lastans * opt)) % n + 1, b = (r ^ (lastans * opt)) % n + 1;
l = min(a, b), r = max(a, b), lastans = 0;
go(1, 1, n, l, r);
printf("%lld
", lastans);
}
return 0;
}