很多人是根号分治,来一发 polylog 的做法。
线段树/树状数组好题啊。
因为懒下面复杂度均视作 (a=n^{mathcal{O}(1)})。
首先考虑求出 (p) 的差分数组,也就是求出:
[p'_k=sum_{1leq ileq n} a_imod a_j+a_jmod a_i
]
分类讨论一下:
- (a_i mod a_j, a_i<a_j)
- (a_j mod a_i, a_i>a_j)
直接树状数组,前者贡献均为 (a_i),记录一下前面大于 (a_i) 的有多少个。后者贡献均为 (a_j),记录一下前面小于 (a_i) 的 (a_j) 的和。
- (a_i mod a_j, a_i>a_j)
维护一个 (cost_i),代表考虑到当前位,前面的数对 (i) 的贡献是多少。
发现每次是在 (a_i) 的倍数中间的一段都会有个等差数列加上去,直接用线段树维护差分数组一段一段暴力改,由于 (a_i) 互不相同,所以等差数列个数是调和级数 (mathcal{O}(nlog n)) 的,总复杂度是 (mathcal{O}(nlog^2 n)) 的。
- (a_j mod a_i, a_i<a_j)
考虑枚举 (a_i) 的倍数分出的一个个段,统计这个段里 (a_j) 对 (a_i) 的贡献。
那么答案就是
[egin{aligned}
&sum_{lleq i< r}(i-l)cdot a_i
\
=&sum_{lleq i< r}icdot a_i-lcdot a_i
end{aligned}
]
只需要维护 (icdot a_i) 和 (a_i) 在值域维的区间和即可。枚举 (a_i) 倍数的次数是调和级数 (mathcal{O}(nlog n)) 的,用树状数组维护,总复杂度是 (mathcal{O}(nlog^2n)) 的。
(mathcal{Code})
#include<iostream>
#include<cstdio>
#include<vector>
#include<algorithm>
typedef long long ll;
template <typename T> T Max(T x, T y) { return x > y ? x : y; }
template <typename T> T Min(T x, T y) { return x < y ? x : y; }
template <typename T>
T& read(T& r) {
r = 0; bool w = 0; char ch = getchar();
while(ch < '0' || ch > '9') w = ch == '-' ? 1 : 0, ch = getchar();
while(ch >= '0' && ch <= '9') r = r * 10 + (ch ^ 48), ch = getchar();
return r = w ? -r : r;
}
inline int lowbit(int x) { return x & (-x); }
const int N = 300010;
int n, a[N];
ll p[N];
struct BIT {
int m;
ll tree[N];
void init(int x) {
m = x;
for(int i = 0; i <= m; ++i) tree[i] = 0;
}
void modify(int x, ll v) { for(; x <= m; x += lowbit(x)) tree[x] += v; }
ll querysum(int x) { ll sumq = 0; for(; x; x -= lowbit(x)) sumq += tree[x]; return sumq; }
ll query(int x, int y) { return querysum(y)-querysum(x-1); }
}bit1, bit2, tr2, tr3;
struct segment_tree {
#define ls tree[x].lson
#define rs tree[x].rson
#define tl tree[x].l
#define tr tree[x].r
int trnt;
struct SGT {
int l, r, lson, rson;
ll sum, tag;
}tree[N << 1];
inline void pushup(int x) { tree[x].sum = tree[ls].sum + tree[rs].sum; }
inline void pushdown(int x) {
if(tree[x].tag) {
int p = tree[x].tag;
tree[ls].sum += (tree[ls].r - tree[ls].l + 1) * p;
tree[rs].sum += (tree[rs].r - tree[rs].l + 1) * p;
tree[ls].tag += p;
tree[rs].tag += p;
tree[x].tag = 0;
}
}
int build(int l, int r) {
int x = ++trnt; tree[x].sum = 0;
tl = l; tr = r;
if(l == r) return x;
ls = build(l, (l+r)>>1); rs = build(tree[ls].r+1, r);
pushup(x); return x;
}
void modify(int x, int l, int r, ll v) {
if(tl >= l && tr <= r) {
tree[x].sum += (tree[x].r - tree[x].l + 1) * v;
tree[x].tag += v;
return ;
}
int mid = (tree[x].l + tree[x].r) >> 1;
pushdown(x);
if(mid >= l) modify(ls, l, r, v);
if(mid < r) modify(rs, l, r , v);
pushup(x);
}
ll query(int x, int l, int r) {
if(tl >= l && tr <= r) return tree[x].sum;
int mid = (tree[x].l + tree[x].r) >> 1; ll sumq = 0;
pushdown(x);
if(mid >= l) sumq += query(ls, l, r);
if(mid < r) sumq += query(rs, l, r);
pushup(x);
return sumq;
}
}tr1;
#undef ls
#undef rs
#undef tl
#undef tr
int mx = 0;
signed main() {
read(n);
for(int i = 1; i <= n; ++i) read(a[i]), mx = Max(mx, a[i]);
bit1.init(mx); bit2.init(mx); tr2.init(mx); tr3.init(mx);
for(int i = 1; i <= n; ++i) {
p[i] += 1ll * a[i] * bit1.query(a[i]+1, mx);
p[i] += bit2.query(1, a[i]-1);
bit1.modify(a[i], 1);
bit2.modify(a[i], a[i]);
}
tr1.build(1, mx);
for(int i = 1; i <= n; ++i) {
p[i] += tr1.query(1, 1, a[i]);
for(int j = 1; j * a[i] <= mx; ++j) {
p[i] -= j*a[i] * tr2.query(j*a[i], Min(mx, (j+1)*a[i]-1));
p[i] += tr3.query(j*a[i], Min(mx, (j+1)*a[i]-1));
}
for(int j = 1; j * a[i] + 1 <= mx; ++j) {
tr1.modify(1, j*a[i]+1, Min(mx, (j+1)*a[i]-1), 1);
if((j+1)*a[i] <= mx)tr1.modify(1, (j+1)*a[i], (j+1)*a[i], -(Min(mx, (j+1)*a[i]-1) - (j*a[i]+1) + 1));
}
if(a[i] == 1) continue ;
tr2.modify(a[i], 1);
tr3.modify(a[i], a[i]);
}
for(int i = 1; i <= n; ++i) p[i] += p[i-1], printf("%lld ", p[i]);
return 0;
}