https://daniu.luogu.org/problem/show?pid=2042
一道伸展树维护数列的很悲伤的题目,共要维护两个标记和两个数列信息,为了维护MAX-SUM还要维护从左端开始的数列的最大和及到右端结束的数列的最大和。
按照伸展树的套路,给数列左右两边加上不存在的边界节点,给每个子树的空儿子指向哨兵节点。
维护最大子数列和
题目说的子数列其实要求至少包含一个元素,这就要很恶心的维护方法。
(其实让max_sum可以不含元素也能过90%)
每个节点定义max_sum:该节点的最大数列和(至少包含一个元素)
max_lsum:该节点的从左端开始的最大数列和(可以不包含元素)
max_rsum:该节点的到右端结束的最大数列和(可以不包含元素)
按照分冶法,max_sum=max{左儿子max_sum,右儿子max_sum,左儿子max_rsum+该节点的值+右儿子max_lsum}。
如果它和它的左右儿子都是普通节点,这个转移保证至少有一个元素。
如果它是普通节点或边界节点,它的左或右儿子是哨兵节点,则左儿子max_sum或右儿子max_sum是不可取的。故令哨兵节点的max_sum=-inf。
如果它是边界节点,它必定至多有一个儿子,令它的max_sum等于它的唯一儿子的max_sum,max_lsum与max_rsum同理。
覆盖子数列和翻转子数列
每个节点定义两个标记replaced和reversed。
replaced:这个节点及它的所有后代都应该修改为一个特定的值,但实际上只有这个节点的值已经修改。
reversed:这个节点及它的所有后代都应该交换左右子树(max_lsum和max_rsum也应该跟着交换),但实际上只有这个节点的左右子树已经交换。
可见这两个标记是互斥的,且replaced标记的优先级显然大于reversed标记。
打标记的时候注意维护每个结点的标记至多有一个就可以了。
350行的不压行代码,6.33KB,调了近8小时交了差不多二十遍才AC:
#include <algorithm> #include <cctype> #include <iostream> #include <string> using namespace std; void getstr(string &s) { int c; s = ""; while (!isalpha(c = getchar())) { if (c == EOF) return; } do s += (char)c; while (isalpha(c = getchar()) || c == '-'); } void getint(int &x) { int c; bool flag = false; x = 0; while (!isdigit(c = getchar())) { if (c == EOF) return; if (c == '-') flag = true; } do x = x * 10 + c - '0'; while (isdigit(c = getchar())); if (flag) x = -x; } namespace splay { const int inf = 0x7fffffff; enum direction { l = 0, r }; struct node; node *nil = 0, *l_edge, *r_edge; struct node { int val, size; node *ch[2]; int sum; int max_sum, max_lsum, max_rsum; // max_sum 定义为最少包含一个元素的最大子数列和 // max_lsum 定义为从左端开始的可以不包含元素的最大子数列和 // max_lsum 定义为到右端结束的可以不包含元素的最大子数列和 bool replaced, reversed; // 当replaced为true,表示它的所有后代的val应该与这个节点的val相同,但实际上后代节点并没有更新 // 当reversed为true,表示它已经交换了左右节点和左右最大值,且它的所有后代都应该交换左右子树和左右最大值,但实际上后代节点并没有更新 node(int v) : val(v), size(1), sum(v), replaced(false), reversed(false) { ch[l] = ch[r] = nil; if (v >= 0) max_sum = max_lsum = max_rsum = sum; else { max_sum = v; max_lsum = max_rsum = 0; } } int cmp(int k) { if (k == ch[l]->size + 1 || this == nil) return -1; else return k <= ch[l]->size ? l : r; } void reverse() { if (!replaced) { reversed ^= 1; swap(ch[l], ch[r]); swap(max_lsum, max_rsum); } } void replace(int v) { reversed = false; replaced = true; val = v; sum = v * size; if (v > 0) max_sum = max_lsum = max_rsum = sum; else { max_sum = v; // 由于子数列要求至少有一个元素,故当 val < 0 // ,只有一个元素时和最大 max_lsum = max_rsum = 0; } } void push_down() { if (replaced) { ch[l]->replace(val); ch[r]->replace(val); replaced = false; } else if (reversed) { ch[l]->reverse(); ch[r]->reverse(); reversed = false; } } void pull_up() { if (this != nil) { size = ch[l]->size + ch[r]->size + 1; if (!replaced) sum = ch[l]->sum + ch[r]->sum + val; else sum = val * size; if (this != l_edge && this != r_edge) { max_sum = max(ch[l]->max_rsum + val + ch[r]->max_lsum, max(ch[l]->max_sum, ch[r]->max_sum)); // 更新后 max_sum 至少包含一个元素 max_lsum = max( ch[l]->max_lsum, ch[l]->sum + val + ch[r]->max_lsum); // 更新后 max_lsum / max_rsum 可以不包含元素 max_rsum = max(ch[r]->max_rsum, ch[l]->max_rsum + val + ch[r]->sum); } else if (this == l_edge) // 注意特判左右边界节点 { // 若不特判,当左边界节点为根且整个数列的从左开始的最大值为0时 // 就会出现 max_sum = ch[l]->max_rsum + val + ch[r]->max_lsum // 即 max_sum = 0,这显然不合法 max_sum = ch[r]->max_sum; max_lsum = ch[r]->max_lsum; max_rsum = ch[r]->max_rsum; } else { // 右边界同理 max_sum = ch[l]->max_sum; max_lsum = ch[l]->max_lsum; max_rsum = ch[l]->max_rsum; } } } void remove() { if (this != nil) { ch[l]->remove(); ch[r]->remove(); delete this; } } } * root; void init() { if (!nil) nil = new node(0); nil->size = 0; nil->ch[l] = nil->ch[r] = nil; nil->max_sum = -inf; l_edge = new node(0), r_edge = new node(0); l_edge->max_sum = -inf; r_edge->max_sum = -inf; root = nil; } void rotate(node *&t, int d) { t->push_down(); t->ch[l]->push_down(); t->ch[r]->push_down(); node *k = t->ch[d ^ 1]; t->ch[d ^ 1] = k->ch[d]; k->ch[d] = t; t->pull_up(); k->pull_up(); t = k; } void splay(node *&t, int k) { t->push_down(); int d = t->cmp(k); if (d == r) k = k - t->ch[l]->size - 1; if (d != -1) { t->ch[d]->push_down(); int d2 = t->ch[d]->cmp(k); int k2 = (d2 == r) ? k - t->ch[d]->ch[l]->size - 1 : k; if (d2 != -1) { splay(t->ch[d]->ch[d2], k2); if (d == d2) { rotate(t, d ^ 1); rotate(t, d ^ 1); } else { rotate(t->ch[d], d2 ^ 1); rotate(t, d ^ 1); } } else rotate(t, d ^ 1); } } void join(node *&t1, node *&t2) { if (t1 == nil) swap(t1, t2); splay(t1, t1->size); t1->ch[r] = t2; t2 = nil; t1->pull_up(); } node *split(node *&t, int k) { if (k == 0) { node *subtree = t; t = nil; return subtree; } splay(t, k); node *subtree = t->ch[r]; t->ch[r] = nil; t->pull_up(); return subtree; } node *build_tree(int *p, int n) { if (n == 0) return nil; node *fa; node *ch = new node(p[1]); for (int i = 2; i <= n; i++) { fa = new node(p[i]); fa->ch[l] = ch; fa->pull_up(); ch = fa; } return fa; } node *select(int p, int tot) { int ln = p, rn = ln + tot - 1; splay(root, rn + 1); splay(root->ch[l], ln - 1); return root->ch[l]->ch[r]; } } int n, m; int num[500005]; int main() { using namespace splay; ios::sync_with_stdio(false); getint(n); getint(m); for (int i = 1; i <= n; i++) getint(num[i]); init(); node *t1, *t2; // tmp root = l_edge; t1 = build_tree(num, n); join(root, t1); t1 = r_edge; join(root, t1); string opt; int posi, tot, c; while (m--) { getstr(opt); switch (opt[0]) { case 'I': // INSERT getint(posi); getint(tot); posi++; for (int i = 1; i <= tot; i++) getint(num[i]); t1 = build_tree(num, tot); t2 = split(root, posi); join(root, t1); join(root, t2); break; case 'D': // DELETE getint(posi); getint(tot); posi++; t1 = split(root, posi - 1); t2 = split(t1, tot); join(root, t2); t1->remove(); break; case 'R': // REVERSE getint(posi); getint(tot); posi++; t1 = select(posi, tot); t1->reverse(); root->ch[l]->pull_up(); root->pull_up(); break; case 'G': // GET-SUM getint(posi); getint(tot); posi++; t1 = select(posi, tot); cout << t1->sum << endl; break; case 'M': if (opt[2] == 'K') // MAKE_SAME { getint(posi); getint(tot); getint(c); posi++; t1 = select(posi, tot); t1->replace(c); root->ch[l]->pull_up(); root->pull_up(); } else // MAX_SUM cout << root->max_sum << endl; break; } } return 0; }