HDU-4747 Mex 线段树应用 Mex性质
题意
给定长度为(n)的数组(a),求
[sum sum mex(i,j)
]
其中(mex(i,j))表示区间(mex(a_i...a_j)的值)
[1leq n leq 2 imes 10^5\
1leq a_i leq 10^9
]
分析
此题我认为还是不太好想到的
首先如果只求一维,由于单调性,求(sum mex(1,i))是可以在(O(n))下完成的。
然后注意到第二维即(sum mex(2,i))该如何计算,这个时候(1)相当于没有了,1在这一维上产生的影响就是当前下一个等于(a[1])的元素之前的一段。后面的显然和原来的保持不变,这就让我们想到了用区间维护。
那么(a[1])会如何影响([2,next[a[1]] - 1])呢?
再次想到(mex)在“前缀”意义上的单调性,我们只需要把其中大于(a[1])的部分变为(a[1])即可,其他部分的(mex)并不会受影响
最后由于递推,不要忘记单点修改。
所以问题就转化成了
- 求出每个数的下一个等于它的数出现的位置
- 求出第一个大于等于(a[i])的位置
- 修改某一段区间的值
这些都可以用线段树实现,当然要注意一些细节,比如(lazy)标记应该设置(-1),否则(mx)会无法下传,以及下一个位置数组应该在最后加上(n + 1)点
代码
struct Tree {
int lazy;
int sum;
int mx;
int l, r;
};
int n;
Tree node[maxn << 2];
int a[maxn];
int mex[maxn];
int nxt[maxn];
void push_up(int i) {
node[i].sum = node[i << 1].sum + node[i << 1 | 1].sum;
node[i].mx = max(node[i << 1].mx, node[i << 1 | 1].mx);
}
void build(int i, int l, int r) {
node[i].l = l;
node[i].r = r;
if (l == r) {
node[i].sum = mex[l];
node[i].mx = mex[l];
return;
}
int mid = l + r >> 1;
build(i << 1, l, mid);
build(i << 1 | 1, mid + 1, r);
push_up(i);
}
void push_down(int i, int m) {
if (node[i].lazy >= 0) {
node[i << 1].lazy = node[i].lazy;
node[i << 1 | 1].lazy = node[i].lazy;
node[i << 1].sum = node[i].lazy * (m - (m >> 1));
node[i << 1 | 1].sum = node[i].lazy * (m >> 1);
node[i << 1].mx = node[i << 1 | 1].mx = node[i].lazy;
node[i].lazy = -1;
}
}
void update(int i, int l, int r, int val) {
if (node[i].l > r || node[i].r < l) return;
if (node[i].l >= l && node[i].r <= r) {
//bug;
node[i].lazy = val;
node[i].sum = (node[i].r - node[i].l + 1) * val;
node[i].mx = val;
//cout << node[i].mx << ' ' << i << '
';
return;
}
push_down(i, node[i].r - node[i].l + 1);
update(i << 1, l, r, val);
update(i << 1 | 1, l, r, val);
push_up(i);
}
int query(int i, int x) {
if (node[i].l == node[i].r) {
return node[i].l;
}
if (node[i << 1].mx > x) return query(i << 1, x);
else return query(i << 1 | 1, x);
}
signed main() {
while (scanf("%lld", &n)) {
if (!n) break;
for (int i = 1; i <= n; i++)
a[i] = readint(), nxt[i] = n + 1;
for (int i = 1; i < 4 * n; i++) {
node[i].l = node[i].r = node[i].sum = node[i].lazy = -1, node[i].mx = 0;
}
unordered_map<int, int> mp;
int cur = 0;
for (int i = 1; i <= n; i++) {
if (mp[a[i]]) nxt[mp[a[i]]] = i, mp[a[i]] = i;
else mp[a[i]] = i;
while (mp[cur]) cur++;
mex[i] = cur;
}
build(1, 1, n);
ll ans = 0;
for (int i = 1; i <= n; i++) {
ans += node[1].sum;
if (node[1].mx > a[i]) {
int l = query(1,a[i]);
int r = nxt[i] - 1;
if (l <= r) update(1, l, r, a[i]);
}
update(1, i, i, 0);
}
cout << ans << '
';
}
}