P4735 最大异或和(可持久化trie)
题目描述
给定一个非负整数序列 ({a}),初始长度为 n。
有 m 个操作,有以下两种操作类型:
A x:添加操作,表示在序列末尾添加一个数 x,序列的长度 (n + 1)。
Q l r x:询问操作,你需要找到一个位置 p,满足(l le p le r),使得: (a[p] oplus a[p+1] oplus ... oplus a[N] oplus x) 最大,输出最大是多少。
思路
异或满足可加性,所以对 ({a}) 求前缀和 ({s}), (a[p] oplus a[p + 1] oplus a[p + 2] ... oplus a[n] oplus x = s[p - 1] oplus s[n] oplus x)。问题转化为已知数 (s[n] oplus x) 求 ({s}) 中的一个数满足在给定区间并且和 (s[n] oplus x) 异或起来最大。
首先,不考虑区间限制,单纯找一个数和已知数异或和最大的话,可以对 ({s}) 构建 trie 树。从最高位按位遍历已知数字,如果 trie 树上当前结点存在和这一位相反的数,就去相反的儿子,实在不行再去相同的儿子。这样贪心找出来的数和已知数字的异或和最大。
至于区间限制,导致 trie 树上的一些在 l - 1 之前加入的结点不能用, 在 r 及之后加入的结点不能用(注意这里l-1能选因为异或s[l-1]是保留到l;r不能选是因为异或s[r]连r也删了)。对于 r 的限制,我们可以可持久化,也就是记录所有时间内 tire 树的状态,也就是一共有 n 棵 trie, 我们在第 r - 1 棵trie上进行贪心的按位选择。对于 l 的限制,我们可以记录 trie 树上每一个结点的最晚时间戳 lst[]
,最晚时间戳要小于 l - 1 的结点也不能用。
可持久化就是在建第 i 棵树的时候尽可能多的去嫖一些第 i - 1 棵树的结点,所以建树的参数要传两棵树的。根节点肯定不能嫖,新添加的数字 s[i] 所产生的结点也要用新的不能嫖。所谓能嫖的结点,就是沿用不符合当前数位的无关紧要的点。这些点虽然在现在用不到,但是在往后的时间里构建 trie 还是要在此基础上加工的。
插入和查询操作
//位数=log(1e7)(2)=23
int maxn = 300000;
int lst[maxn * 24], trie[maxn * 24], cnt, root[maxn];
void insert(int timepost, int val, int k, int p, int q){
if(k < 0){
lst[p] = timepost;
return;
}
int c = val >> k & 1;
if(q){
trie[p][c ^ 1] = trie[q][c ^ 1];
}
trie[p][c] = ++cnt;
insert(timepost, val, k - 1, trie[p][c], trie[q][c]);
lst[p] = max(lst[trie[p][0]], lst[trie[p][1]]);
}
int query(int timepost, int val, int k, int nownode){
if(k < 0){
return s[lst[nownode]] ^ val;
}
int c = val >> k & 1;
if(lst[nownode][c] >= timepost){
return query(timepost, val, k - 1, trie[nownode][c ^ 1]);
}
else return query(timepost, val, k -1, tire[nownode][c]);
}
重要的事情
-
有关数组大小
理论上来说,假设所有(n+m)棵trie都开满(2*{maxlen})个点,总的数组大小开trie[(maxn + maxm) * 2]
也就是2倍常数是保险的。这个题的话开到 500000 即可AC,但是无法确定所用点数最多的那个数据的 (n + m) 是几,目前不知道关于tire
树棵树的常数应该开到几。目前还是能开2倍就多开一点。 -
有关第0棵
trie
第0棵trie
是存在的,如果我们最后的答案是(a[1]oplus a[2]oplus ... oplus[n])的话,就需要进行(oplus s[0])操作,这意味着创造出(a[0] = s[0] = 0)是必须的。 -
有关
lst
注意lst
为0值时是指第0棵trie
,无定义结点的lst
应该赋值为-1。 -
有关
l-1
和r-1
这个在上面都说过了,注意一下。 -
有关
UB
Undefined Behavoir
比如这一句容易挂的语句s[++n]=s[n-1]^x
就很有问题。事实上,等号的优先级是最低的,这意味着等式的右边总是先于左边计算从而应该改成s[n]=s[(++n)-1]^x
。
后续
可持久化01tire是上古时期学的,昨天准备要写,今天学会并且写了,花了一下午和半个晚上的时间去过板子题。期间我发帖求助,炸出了一群巨神来帮我,好感动,又找到了曾经的感觉了,生活又美好了起来,又想码字了。
AC代码
#include<set>
#include<map>
#include<queue>
#include<cmath>
#include<ctime>
#include<stack>
#include<vector>
#include<cstdio>
#include<string>
#include<cstring>
#include<cstdlib>
#include<iomanip>
#include<iostream>
#include<algorithm>
#include<functional>
#define inf 0x3fffffff
#define ls p * 2
#define rs p * 2 + 1
#define fi first
#define se second
#define pb push_back
#define mp make_pair
using namespace std;
typedef long long ll;
typedef pair<int,int> pi;
typedef vector<int> vi;
typedef unsigned int ui;
int rd(){
int res = 0, fl = 1;
char c = getchar();
while(!isdigit(c)){
if(c == '-') fl = -1;
c = getchar();
}
while(isdigit(c)){
res = (res << 3) + (res << 1) + c - '0';
c = getchar();
}
return res * fl;
}
const int maxn = 500000;
int trie[maxn * 25][2], root[maxn], cnt, lst[maxn * 25];
string op;
int n, m, l, r;
int s[maxn], x, ans;
void insert(int timepost, int val, int k, int nownode, int lstnode){
if(k < 0){
lst[nownode] = timepost;
return;
}
int c = (val >> k) & 1;
trie[nownode][c] = ++cnt;
if(trie[lstnode][c ^ 1]){
trie[nownode][c ^ 1] = trie[lstnode][c ^ 1];
}
insert(timepost, val, k - 1, trie[nownode][c], trie[lstnode][c]);
lst[nownode] = max(lst[trie[nownode][0]], lst[trie[nownode][1]]);
return;
}
int query(int timepost, int val, int k, int nownode){
if(k < 0){
return s[lst[nownode]];
}
int c = val >> k & 1;
if(lst[trie[nownode][c ^ 1]] >= timepost){
return query(timepost, val, k - 1, trie[nownode][c ^ 1]);
}
else return query(timepost, val, k - 1, trie[nownode][c]);
}
int main(){
memset(lst, -1, sizeof(lst));
n = rd(); m = rd();
root[0] = ++cnt;
insert(0, 0, 24, root[0], root[0]);
for(int i = 1; i <= n; ++i){
x = rd();
s[i] = s[i - 1] ^ x;
root[i] = ++cnt;
insert(i, s[i], 24, root[i], root[i - 1]);
}
for(int i = 1; i <= m; ++i){
cin >> op;
if(op[0] == 'A'){
x = rd();
s[n] = s[(++n)-1] ^ x;
root[n] = ++cnt;
insert(n, s[n], 24, root[n], root[n - 1]);
}
else if(op[0] == 'Q'){
l = rd(); r = rd(); x = rd();
ans = s[n];
ans = ans ^ query(l - 1, x ^ s[n], 24, root[r - 1]);
ans = ans ^ x;
printf("%d
", ans);
}
}
return 0;
}