题目
题解
一个数最多加(log n)次lowbit,之后只需乘2即可。因此可以结合线段树暴力,没好的暴力加,加好的直接打标记乘2。
原本我的方法是并查集维护那些区间乘2,那些区间暴力加,并查集合并。这样做时间复杂度相似,但是常数巨大。除了并查集本身的复杂度,每次更新都是从线段树的([1,n])第一层开始向下更新,这是常数很大的(O(log n))。
因此直接在线段树里维护就好,用一个tag
数组标记当前区间是需要递归下去加lowbit还是直接lazy标记乘2。
#include <bits/stdc++.h>
#define endl '
'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N)
typedef long long ll;
using namespace std;
/*-----------------------------------------------------------------*/
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f
const int N = 3e5 + 10;
const int M = 998244353;
const double eps = 1e-5;
int arr[N];
ll sum[N << 2], pw[N];
int tag[N << 2], lazy[N << 2];
ll lowbit(ll x) {
return x&-x;
}
void pushdown(int rt) {
if(lazy[rt]) {
lazy[rt << 1] += lazy[rt];
lazy[rt << 1 | 1] += lazy[rt];
sum[rt << 1] = sum[rt << 1] * pw[lazy[rt]] % M;
sum[rt << 1 | 1] = sum[rt << 1 | 1] * pw[lazy[rt]] % M;
lazy[rt] = 0;
}
}
void init(int l, int r, int rt) {
if(l == r) {
sum[rt] = arr[l];
tag[rt] = (arr[l] == lowbit(arr[l]));
return ;
}
int mid = (l + r) / 2;
init(l, mid, rt << 1);
init(mid + 1, r, rt << 1 | 1);
sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % M;
tag[rt] = (tag[rt << 1] && tag[rt << 1 | 1]);
}
void update(int l, int r, int L, int R, int rt) {
if(tag[rt] && l >= L && r <= R) {
lazy[rt]++;
sum[rt] = sum[rt] * 2 % M;
return ;
}
if(l == r) {
sum[rt] += lowbit(sum[rt]);
if(sum[rt] == lowbit(sum[rt])) tag[rt] = 1;
return ;
}
pushdown(rt);
int mid = (l + r) /2;
if(L <= mid) update(l, mid, L, R, rt << 1);
if(R > mid) update(mid + 1, r, L, R, rt << 1 | 1);
sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % M;
tag[rt] = (tag[rt << 1] && tag[rt << 1 | 1]);
}
ll que(int l, int r, int L, int R, int rt) {
if(l >= L && r <= R) {
return sum[rt];
}
pushdown(rt);
ll res = 0;
int mid = (l + r) /2;
if(L <= mid) res += que(l, mid, L, R, rt << 1);
if(R > mid) res += que(mid + 1, r, L, R, rt << 1 | 1);
sum[rt] = (sum[rt << 1] + sum[rt << 1 | 1]) % M;
tag[rt] = (tag[rt << 1] && tag[rt << 1 | 1]);
return res % M;
}
int main() {
pw[0] = 1;
for(int i = 1; i < N; i++) pw[i] = pw[i - 1] * 2 % M;
IOS;
int t;
cin >> t;
while(t--) {
int n;
cin >> n;
for(int i = 1; i <= n; i++) cin >> arr[i];
init(1, n, 1);
int q;
cin >> q;
while(q--) {
int op, l, r;
cin >> op >> l >> r;
if(op == 1) {
update(1, n, l, r, 1);
} else {
cout << que(1, n, l, r, 1) << endl;
}
}
}
}