一眼线段树...显然,我们可以考虑最后所留下的区间,那显然这个区间中应当不能存在任何与区间外相同的颜色。这里的转化也是很常用的,我们用 (nxt[i]) 表示与 (i) 颜色相同的下一个位置在哪里, (last[i]) 表示与 (i) 颜色相同的上一个位置在哪里;那么一个区间 (i, j) 是满足要求的当且仅当 (min(last[k]) >= i, max(nxt[k]) <= j (i <= k <= j)) 。我们可以用单调栈处理出 (lim[i]) 记录下第一个 (last[k] < i (k >= i)) 的 (k)。那么我们可以发现以 (i) 为区间左端点的区间右端点一定在 ([i, lim[i] - 1]) 之间。我们考虑如何满足有关 (j) 的限制。
我们由于左端点是从右往左的扫描线,所以考虑答案的时候也只需要考虑当前 (>= i) 的 (j) (满足当前考虑的限制均是在 ([i, n]) 范围内的限制)。由于 (max(nxt[k]) <= j) ,所以对于所有的 (k) 与 (nxt[k]) 而言,它所影响到的右端点就是在 ([k, nxt[k] - 1]) 范围内的节点,让他们都无法与之后所有的左端点匹配成为合法的区间。我们只需要维护一棵每碰到限制就区间赋值为 (0)的线段树 ,并查询区间内的总和就可以计算出合法的区间总数啦。
#include <bits/stdc++.h> using namespace std; #define maxn 1000000 #define LL long long #define INF 99999999999LL int n, a[maxn], rec[maxn], mark[maxn]; int top, S[maxn], nxt[maxn], last[maxn], lim[maxn]; int sum[maxn]; LL ans; int read() { int x = 0, k = 1; char c; c = getchar(); while(c < '0' || c > '9') { if(c == '-') k = -1; c = getchar(); } while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); return x * k; } void Push_down(int p) { if(!mark[p]) return; int ls = p << 1, rs = p << 1 | 1; mark[ls] = mark[rs] = 1; sum[ls] = sum[rs] = 0; mark[p] = 0; } int Query(int p, int l, int r, int L, int R) { if(L > R) return 0; if(L > r || R < l) return 0; if(L <= l && R >= r) return sum[p]; int mid = (l + r) >> 1; Push_down(p); return Query(p << 1, l, mid, L, R) + Query(p << 1 | 1, mid + 1, r, L, R); } void Update(int p, int l, int r, int L, int R) { if(L > r || R < l) return; if(L <= l && R >= r) { mark[p] = 1, sum[p] = 0; return; } int mid = (l + r) >> 1; Push_down(p); Update(p << 1, l, mid, L, R), Update(p << 1 | 1, mid + 1, r, L, R); sum[p] = sum[p << 1] + sum[p << 1 | 1]; } void Build(int p, int l, int r) { if(l == r) { sum[p] = 1; return; } int mid = (l + r) >> 1; Build(p << 1, l, mid), Build(p << 1 | 1, mid + 1, r); sum[p] = sum[p << 1] + sum[p << 1 | 1]; } void init() { memset(mark, 0, sizeof(mark)); memset(rec, 0, sizeof(rec)); ans = 0; } signed main() { int T = read(); while(T --) { init(); n = read(); for(int i = 1; i <= n; i ++) a[i] = read(); for(int i = 1; i <= n; i ++) { last[i] = nxt[i] = 0; last[i] = rec[a[i]], nxt[rec[a[i]]] = i; rec[a[i]] = i; } for(int i = 1; i <= n; i ++) if(!last[i]) last[i] = n + 2; for(int i = n; i >= 1; i --) { while(top >= 1 && last[S[top]] > last[i]) top --; S[++ top] = i; while(top >= 1 && last[S[top]] >= i) top --; if(!top) lim[i] = n + 1; else lim[i] = S[top]; } Build(1, 1, n); for(int i = n; i >= 1; i --) { if(nxt[i]) Update(1, 1, n, i, nxt[i] - 1); ans += (LL) Query(1, 1, n, i, lim[i] - 1); } printf("%lld ", ans); } return 0; }