题意:
(S)为一个自然数集合,定义函数(mex(S))为集合中没有出现的最小自然数。
给出一个长度为为(n)序列(a),设(S_{l,r})表示由(a_l sim a_r)构成的集合。
求:
[sumlimits_{1 leq l leq r leq n}mex(S_{l,r})
]
分析:
有这样一个事实:往集合(S)中任意加入一个元素,(mex(S))的值不会变小。
固定区间左端点来统计答案。
首先计算一下(mex{S_{1,1}},mex{S_{1,2}}, cdots, mex{S_{1,n}}),所以这是一个非递减的序列。
假设现在计算出(mex{S_{i,i}},mex{S_{i,i+1}}, cdots, mex{S_{i,n}}),考虑区间左端点向右移动。
相当于从这些集合中都删去了一个(a_i),如果有一个最小的(j>i)且(a_i=a_j),那么删去(a_i)对([j,n])这段区间没有影响,因为这段区间对应的集合没有改变。
然后考虑区间([i+1,j-1]),找到(mex)值大于(a_i)的区间,把它们的值都变为(a_i)。
因为集合中少了(a_i),所以根据(mex)函数的定义,(mex)值为(a_i)。
而且由于区间是非递减的,所以(mex)值大于(a_i)的区间也是连续的,用线段树维护即可。
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int maxn = 200000 + 10;
const int maxnode = maxn * 4;
int n;
int a[maxn], b[maxn], tot;
int pos[maxn], nxt[maxn];
bool vis[maxn];
int mex[maxn];
//Segment Tree
LL sum[maxnode];
int setv[maxnode], minv[maxnode], maxv[maxnode];
void pushup(int o) {
sum[o] = sum[o<<1] + sum[o<<1|1];
minv[o] = min(minv[o<<1], minv[o<<1|1]);
maxv[o] = max(maxv[o<<1], maxv[o<<1|1]);
}
void build(int o, int L, int R) {
if(L == R) {
sum[o] = minv[o] = maxv[o] = mex[L];
return;
}
int M = (L + R) / 2;
build(o<<1, L, M);
build(o<<1|1, M+1, R);
pushup(o);
}
void pushdown(int o, int L, int R) {
if(setv[o] != -1) {
int lc = o<<1, rc = o<<1|1;
setv[lc] = setv[rc] = setv[o];
minv[lc] = minv[rc] = setv[o];
maxv[lc] = maxv[rc] = setv[o];
int M = (L + R) / 2;
sum[lc] = (LL)setv[o] * (M - L + 1);
sum[rc] = (LL)setv[o] * (R - M);
setv[o] = -1;
}
}
void update(int o, int L, int R, int qL, int qR, int v) {
if(qL <= L && R <= qR && minv[o] > v) {
setv[o] = minv[o] = maxv[o] = v;
sum[o] = (LL)v * (R - L + 1);
return;
}
pushdown(o, L, R);
int M = (L + R) / 2;
if(qL <= M && maxv[o<<1] > v) update(o<<1, L, M, qL, qR, v);
if(qR > M && maxv[o<<1|1] > v) update(o<<1|1, M+1, R, qL, qR, v);
pushup(o);
}
LL query(int o, int L, int R, int qL, int qR) {
if(qL <= L && R <= qR) return sum[o];
pushdown(o, L, R);
int M = (L + R) / 2;
LL ans = 0;
if(qL <= M) ans += query(o<<1, L, M, qL, qR);
if(qR > M) ans += query(o<<1|1, M+1, R, qL, qR);
return ans;
}
int main()
{
while(scanf("%d", &n) == 1 && n) {
for(int i = 1; i <= n; i++) {
scanf("%d", a + i);
if(a[i] >= maxn) a[i] = maxn - 1;
b[i] = a[i];
}
sort(b + 1, b + 1 + n);
tot = unique(b + 1, b + 1 + n) - b - 1;
for(int i = 1; i <= n; i++)
a[i] = lower_bound(b + 1, b + 1 + tot, a[i]) - b;
for(int i = 1; i <= tot; i++) pos[i] = n + 1;
for(int i = n; i > 0; i--) {
nxt[i] = pos[a[i]];
pos[a[i]] = i;
}
memset(vis, false, sizeof(vis));
int p = 0;
for(int i = 1; i <= n; i++) {
vis[b[a[i]]] = true;
while(vis[p]) p++;
mex[i] = p;
}
memset(setv, -1, sizeof(setv));
build(1, 1, n);
LL ans = sum[1];
for(int i = 2; i <= n; i++) {
int j = nxt[i - 1];
if(j > i) update(1, 1, n, i, j - 1, b[a[i-1]]);
ans += query(1, 1, n, i, n);
}
printf("%lld
", ans);
}
return 0;
}