(large{题目链接})
(\)
题意:
给定一个长度为(n)的正整数序列,定义函数(f_{l,r})表示在下标在(left[l,r
ight])的子区间中不同整数的个数。
求:(sum limits^{n}_{l=1} sum limits ^{n}_{r=l}fleft( l,r
ight)^{2} left(mod 1e9 + 7
ight))
(1 leq n leq 10^6)
(\)
思路:
首先看到(10^9)的值域,而且关心的只是数值相等不相等,与具体值无关,先离散化一下。
我们枚举左端点(l),考虑当左端点为(l)的区间对答案的贡献,把这些贡献全部加在一起就是最终的答案。
那么题目就变成求 (sum limits _{i = 1} ^ {n} f(l,i)^2)。
因为(n)的范围是(10^6),显然要找到一种方法能够维护答案。
对于(left[l,n
ight])中出现过的数(x),设它在(left[l,n
ight])出现的最左位置为(pos_x)。记(t_i)为(f(l,i))的值。
考虑倒序循环(l),那么左端点由(l+1)变为(l)的时候,会发生两种事。
1.(t_l,t_{l+1},...,t_{pos_x-1})都加1。
2.(pos_x)变为l。
那么所需要解决的问题就变为了:
1.支持区间修改。
2.求区间的平方和。
可以用线段树维护。如果区间加上(k),那么平方和变为:
[left( a_{l}+k
ight) ^{2}+left( a_{l+1}+k
ight) ^{2}+ldots +left( a_{r}+k
ight) ^{2}
]
[= a^{2}_{l}+2ka_{l} + k^{2} + a^{2}_{l+1}+2ka_{l+1} + k^{2} +...+ a^{2}_{r}+2ka_{r} + k^{2}
]
[= left( a^{2}_{l}+a^{2}_{l+1}+ldots +a^{2}_{r}
ight) + 2k(a_{l}+a_{l+1}+ldots +a_{r}) + (r- l+ 1) imes k ^ 2
]
维护区间和和区间平方和即可。
(\)
代码:
#include <bits/stdc++.h>
#define ls (x << 1)
#define rs (x << 1 | 1)
using namespace std;
typedef long long ll;
const int N = 1e6 + 5;
const int p = 1e9 + 7;
int n, a[N], pos[N];
struct Node {
int id, val;
}b[N];
int read() {
int x = 0;
char c = getchar();
for (; !isdigit(c); c = getchar());
for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
return x;
}
bool cmp(Node x, Node y) { return x.val < y.val; }
struct Segment_tree {
int tl[N << 2], tr[N << 2];
ll t[N << 2], lz[N << 2], s[N << 2];
void build(int x, int l, int r) {
tl[x] = l, tr[x] = r;
if (l == r) return;
int mid = (l + r) >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
}
void up(int x) {
s[x] = s[ls] + s[rs];
if (s[x] > p) s[x] -= p;
t[x] = t[ls] + t[rs];
if (t[x] > p) t[x] -= p;
}
void down(int x) {
if (!lz[x]) return;
s[ls] = (s[ls] + 2 * lz[x] * t[ls] % p + (tr[ls] - tl[ls] + 1) * lz[x] * lz[x] % p) % p;
t[ls] = (t[ls] + (tr[ls] - tl[ls] + 1) * lz[x] % p) % p;
lz[ls] += lz[x];
s[rs] = (s[rs] + 2 * lz[x] * t[rs] % p + (tr[rs] - tl[rs] + 1) * lz[x] * lz[x] % p) % p;
t[rs] = (t[rs] + (tr[rs] - tl[rs] + 1) * lz[x] % p) % p;
lz[rs] += lz[x];
lz[x] = 0;
}
void update(int x, int l, int r, ll k) {
if (l <= tl[x] && r >= tr[x]) {
s[x] = (s[x] + 2 * k * t[x] % p + (tr[x] - tl[x] + 1) * k * k % p) % p;
t[x] = (t[x] + (tr[x] - tl[x] + 1) * k % p) % p;
lz[x] += k;
return;
}
down(x);
int mid = (tl[x] + tr[x]) >> 1;
if (l <= mid) update(ls, l, r, k);
if (r >= mid + 1) update(rs, l, r, k);
up(x);
}
ll query(int x, int l, int r) {
if (l <= tl[x] && r >= tr[x]) return s[x];
ll ret = 0;
int mid = (tl[x] + tr[x]) >> 1;
if (l <= mid) ret = query(ls, l, r);
if (r >= mid + 1) ret = (ret + query(rs, l, r)) % p;
return ret;
}
}T;
int main() {
n = read();
for (int i = 1; i <= n; ++i) b[i].id = i;
for (int i = 1; i <= n; ++i) b[i].val = read();
sort(b + 1, b + 1 + n, cmp);
int cnt = 0;
b[0].val = b[1].val - 1;
for (int i = 1; i <= n; ++i) b[i].val == b[i - 1].val ? a[b[i].id] = cnt : a[b[i].id] = ++cnt;
for (int i = 1; i <= cnt; ++i) pos[i] = n + 1;
T.build(1, 1, n);
ll ans = 0;
for (int i = n; i >= 1; --i) {
T.update(1, i, pos[a[i]] - 1, 1);
ans = (ans + T.query(1, i, n)) % p;
pos[a[i]] = i;
}
printf("%lld
", ans);
return 0;
}