感觉这种题没啥营养，按点权排个序，用并查集合并连通块、算算每条边的贡献就好啦。
// Reads a tree with n weighted vertices and prints the sum, over all unordered
// pairs of distinct vertices, of (maximum weight on the path) minus
// (minimum weight on the path).
//
// Input : n, then a[1..n], then n - 1 edges "u v" (1-based vertex ids).
// Output: the sum as a signed 64-bit integer followed by a single space.
//
// Method: give every edge the weight max(a[u], a[v]) and add edges in
// non-decreasing weight order to a disjoint-set forest. When an edge joins
// components of sizes s1 and s2, exactly s1 * s2 vertex pairs get that edge's
// weight as their path maximum. The same sweep in non-increasing order with
// min(a[u], a[v]) yields the path minima, which are subtracted.
#include <bits/stdc++.h>

#define LL long long
#define fi first
#define se second
#define mk make_pair
#define PLL pair<LL, LL>
#define PLI pair<LL, int>
#define PII pair<int, int>
#define SZ(x) ((int)x.size())
#define ull unsigned long long

using namespace std;

const int N = 1e6 + 7;
const int inf = 0x3f3f3f3f;
const LL INF = 0x3f3f3f3f3f3f3f3f;
const int mod = 1e9 + 7;
const double eps = 1e-8;
const double PI = acos(-1);

int n, id[N], a[N], fa[N], cnt[N];
LL ans;
vector<int> G[N];

// Returns the representative of x's component and compresses the path to it.
// Written iteratively: unions below link roots without rank/size balancing,
// so a parent chain can be as long as n and recursion could exhaust the stack.
int getRoot(int x) {
    int root = x;
    while (fa[root] != root) root = fa[root];
    while (fa[x] != root) {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

// Orders vertex ids by ascending vertex weight.
bool cmp(const int& x, const int& y) {
    return a[x] < a[y];
}

// Resets the disjoint-set forest and processes vertices in the current order
// of id[1..n]. With takeMax == true, id[] must be sorted by ascending weight
// and the result is the sum of path maxima over all pairs; with
// takeMax == false, id[] must be in descending order and the result is the
// sum of path minima.
static LL sweepExtremes(bool takeMax) {
    for (int i = 1; i <= n; i++) {
        fa[i] = i;
        cnt[i] = 1;
    }

    LL total = 0;
    for (int i = 1; i <= n; i++) {
        const int u = id[i];
        for (const int v : G[u]) {
            // An edge is handled from the endpoint that comes later in the
            // sweep; for equal weights the larger id takes it, so every edge
            // is merged exactly once.
            const bool vComesFirst = takeMax ? (a[v] < a[u]) : (a[v] > a[u]);
            if (!vComesFirst && !(a[u] == a[v] && u > v)) continue;

            const int x = getRoot(u);
            const int y = getRoot(v);
            // 1LL forces 64-bit arithmetic before the component sizes multiply.
            total += 1LL * a[u] * cnt[x] * cnt[y];
            fa[y] = x;
            cnt[x] += cnt[y];
        }
    }
    return total;
}

int main() {
    // The global arrays hold at most N - 1 vertices; reject anything else
    // instead of indexing out of bounds.
    if (scanf("%d", &n) != 1 || n < 0 || n >= N) return 1;

    for (int i = 1; i <= n; i++) {
        if (scanf("%d", &a[i]) != 1) return 1;
        id[i] = i;
    }

    for (int i = 2; i <= n; i++) {
        int u, v;
        if (scanf("%d%d", &u, &v) != 2) return 1;
        if (u < 1 || u > n || v < 1 || v > n) return 1;
        G[u].push_back(v);
        G[v].push_back(u);
    }

    // Ascending sweep accumulates path maxima.
    sort(id + 1, id + 1 + n, cmp);
    ans = sweepExtremes(true);

    // Descending sweep accumulates path minima, which are subtracted.
    reverse(id + 1, id + 1 + n);
    ans -= sweepExtremes(false);

    printf("%lld ", ans);
    return 0;
}