一开始想到点分治, 其实不是很好搞.
因为分治每次是计算的过某个点的答案, 所以我们也可以按一定的顺序计算贡献.
因为题目是按照最大值最小值计算贡献的, 所以按照从小到大的方式计算贡献.
先求最大值, 然后一起减去最小值贡献就可以了.
所以我们从小到大排序后, 对于每个相邻联通块之间的链, 这个点就一定为最大值.
所以直接算贡献就可以了.
考虑这一道题目的套路, 同样是统计树上路徑, 点分治一定要按照分治结构来, 每次统计过重心的路徑. 因为这题的最大值跟重心
相关性不强, 所以点分治很难写.
其实树上路径统计总的来说都是算贡献. 关键是按照什么为标准计算贡献, 点分治保证联通块大小是O(logn)
递减的, 这样就可以直接暴力艹.
这种方法的好处主要在于可以按照一定的顺序算贡献, 当可以简单的发现点或边之间可以构造出一定顺序的时候就可以采用.
Code
#include<bits/stdc++.h>
using namespace std;
#define rep(i, a, b) for(int i = (a), i##_end_ = (b); i <= i##_end_; ++i)
#define drep(i, a, b) for(int i = (a), i##_end_ = (b); i >= i##_end_; --i)
#define clar(a, b) memset((a), (b), sizeof(a))
#define debug(...) fprintf(stderr, __VA_ARGS__)
typedef long long LL;
typedef long double LD;
int read() {
char ch = getchar();
int x = 0, flag = 1;
for (;!isdigit(ch); ch = getchar()) if (ch == '-') flag *= -1;
for (;isdigit(ch); ch = getchar()) x = x * 10 + ch - 48;
return x * flag;
}
void write(int x) {
if (x < 0) putchar('-'), x = -x;
if (x >= 10) write(x / 10);
putchar(x % 10 + 48);
}
const int Maxn = 1e6 + 9;
struct edge {
int to, nxt;
}g[Maxn << 1];
int head[Maxn], e, fa[Maxn];
LL ans, a[Maxn], n;
void add(int u, int v) {
g[++e] = (edge){v, head[u]}, head[u] = e;
}
void init() {
clar(head, -1);
n = read();
rep (i, 1, n) a[i] = read();
rep (i, 1, n - 1) {
int u = read(), v = read();
add(u, v), add(v, u);
}
}
struct node {
int Id, Val;
int operator < (const node b) const {
return Val < b.Val;
}
}s[Maxn];
int find(int u) { return fa[u] ^ u ? (fa[u] = find(fa[u])) : u; }
int vis[Maxn], size[Maxn];
void solve() {
rep (i, 1, n) {
s[i] = (node){i, a[i]};
fa[i] = i; vis[i] = 0; size[i] = 1;
}
sort(s + 1, s + n + 1);
rep (i, 1, n) {
int u = s[i].Id; vis[u] = 1;
for (int j = head[u]; ~j; j = g[j].nxt) {
int v = g[j].to;
if (vis[v]) {
int l = find(v);
ans += 1ll * size[l] * size[u] * s[i].Val;
size[u] += size[l], fa[l] = u;
}
}
}
rep (i, 1, n) s[i].Val = -s[i].Val;
rep (i, 1, n) fa[i] = i, vis[i] = 0, size[i] = 1;
sort(s + 1, s + n + 1);
rep (i, 1, n) {
int u = s[i].Id; vis[u] = 1;
for (int j = head[u]; ~j; j = g[j].nxt) {
int v = g[j].to;
if (vis[v]) {
int l = find(v);
ans += 1ll * size[l] * size[u] * s[i].Val;
size[u] += size[l], fa[l] = u;
}
}
}
cout << ans << endl;
}
int main() {
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
init();
solve();
#ifdef Qrsikno
debug("
Running time: %.3lf(s)
", clock() * 1.0 / CLOCKS_PER_SEC);
#endif
return 0;
}