题目
有一个长度为 \(n\) 的数列 \(A\),\(m\) 给定。有 \(q\) 个修改或询问:
-
修改:把 \([l,r]\) 区间中的每个元素 \(x\) 变成 \(ax+b\bmod m\)。
-
询问:询问执行第 \(l\) 次到第 \(r\) 次操作后第 \(k\) 个元素的值。
强制在线,\(n\le 10^5,q\le 6\times 10^5\)。
解法
\(\text{Step 1}\) —— 离线
一个重要的观察是修改操作具有结合律:\(\text{Merge}(a_1x+b_1,a_2x+b_2)=a_1a_2x+(b_1a_2+b_2)\)。这提示了我们将修改放到线段树上合并。
当然我无视了这个提示,我甚至以为询问时的 "第 \(l\) 次到第 \(r\) 次操作" 都包含了第 \(k\) 个元素。于是我想一个数的答案不就是这个:
其中 \(mul_{i,j}\) 是第 \(i\) 次到第 \(j\) 次操作的 \(a\) 的累乘。考虑到 \(\sum\) 中每个 \(b_i\) 的系数的右端点出奇的一致:定义 \(h_i=b_i\cdot mul_{i+1,cnt}\)(一共有 \(cnt\) 次操作),\(s_i=\sum_{j=1}^i h_j\),式子可以表示成:
然后发现把题目看错了,好耶!
不过从中有一个收获:如果对于每个询问,正向计算第 \(l\) 次到第 \(r\) 次操作需要考虑到操作是否包含 \(k\),这是本题的难点;如果保证包含 \(k\),我们可以用线段树快速维护答案。你这样完全等于没说。
如何保证包含 \(k\)?从小到大枚举数列下标 \(i\):
- 修改:如果 \(i\) 是操作左端点,加入;反之,删除。具体维护以修改编号为下标的线段树,加入即为加入线段树的对应下标。
- 询问:此时在线段树中的修改都包含了 \(k\),所以直接询问 \([l,r]\) 即可。
\(\text{Step 2}\) —— 在线
还是维护以修改编号为下标的线段树,但对于树上的每个点 \(o\) 表示区间 \([l,r]\),维护之内操作对每个元素 \(A_i\) 的影响 \((a_i,b_i)\)。当然,我们不能把每个 \(A_i\) 的影响都存下来,但是我们发现影响可以被划分成一段一段地存储(每一段包含的 \(A_i\) 的影响相同)。计算一下这个区间最多会被分成多少段:每加入一个操作最多会增加两段,所以它的级别就是区间长度。考虑每个叶子节点只会被区间包含 \(\log n\) 次,所以总共段数为 \(n\log n\) 级别。合并段就用归并排序,整体复杂度是 \(\mathcal O(n\log n)\) 的。
对于询问需要二分 \(k\) 在哪一段内,所以是 \(\mathcal O(n\log^2 n)\)。
代码
#include<cstdio>
#include<iostream>
#define int long long
using namespace std;
const int N = 6e5 + 2, M = N * 30;
int ans, aa, bb, A, B, ql, qr, pos, type, n, m, v[N], Q, cnt, tot, L[N << 2], R[N << 2], ll[M], rr[M], a[M], b[M];
int read() {
int x = 0, f = 1;
char s = getchar();
while(s > '9' || s < '0') {
if(s == '-') f = -1;
s = getchar();
}
while(s >= '0' && s <= '9') {
x = (x << 1) + (x << 3) + (s ^ 48);
s = getchar();
}
return x * f;
}
void pushUp(const int o) {
int l1 = L[o << 1], l2 = L[o << 1 | 1], r1 = R[o << 1], r2 = R[o << 1 | 1], l0 = 1;
L[o] = tot + 1;
while(l1 <= r1 || l2 <= r2) {
ll[++ tot] = l0; rr[tot] = min(rr[l1], rr[l2]);
a[tot] = a[l1] * a[l2] % m;
b[tot] = (a[l2] * b[l1] + b[l2]) % m;
l0 = rr[tot] + 1;
if(rr[l1] < rr[l2]) ++ l1;
else if(rr[l1] > rr[l2]) ++ l2;
else ++ l1, ++ l2;
}
R[o] = tot;
}
void change(const int o, const int l, const int r) {
if(l > cnt || r < cnt) return;
if(l == r) {
L[o] = tot + 1;
if(ql > 1) ll[++ tot] = 1, rr[tot] = ql - 1, a[tot] = 1;
ll[++ tot] = ql, rr[tot] = qr, a[tot] = A, b[tot] = B;
if(qr < n) ll[++ tot] = qr + 1, rr[tot] = n, a[tot] = 1;
R[o] = tot;
return;
}
int mid = l + r >> 1;
change(o << 1, l, mid);
change(o << 1 | 1, mid + 1, r);
if(r == cnt) pushUp(o);
}
int find(int l, int r) {
int res = 0;
while(l <= r) {
int mid = l + r >> 1;
if(rr[mid] < pos) l = mid + 1;
else r = mid - 1, res = mid;
}
return res;
}
void ask(const int o, const int l, const int r) {
if(ql > r || qr < l) return;
if(l >= ql && r <= qr) {
int id = find(L[o], R[o]);
aa = aa * a[id] % m;
bb = (a[id] * bb + b[id]) % m;
return;
}
int mid = l + r >> 1;
ask(o << 1, l, mid);
ask(o << 1 | 1, mid + 1, r);
}
signed main() {
int op, a, b;
type = (read() & 1), n = read(), m = read();
for(int i = 1; i <= n; ++ i) v[i] = read();
Q = read();
for(int i = 1; i <= Q; ++ i) {
op = read(), ql = read(), qr = read();
if(type) ql ^= ans, qr ^= ans;
if(op == 1) {
A = read() % m, B = read() % m;
++ cnt;
change(1, 1, Q);
}
else {
pos = read();
if(type) pos ^= ans;
aa = 1, bb = 0;
ask(1, 1, Q);
ans = (aa * v[pos] + bb) % m;
printf("%lld\n", ans);
}
}
return 0;
}