[题目链接]
https://www.lydsy.com/JudgeOnline/problem.php?id=5289
[算法]
题目中的限制条件可看作是 : 第i个数必须排在所有权值为i的数前面
那么 , 我们枚举每一个数 , 向所有权值为当前枚举的数的下标的数连边 , 不难发现 , 若无解 , 则说明图中存在环 , 我们可以通过拓扑排序判断图中是否存在环
否则 , 仔细观察后发现这其实是由若干棵树组成的森林 , 我们建立(n + 1)号节点 , 将该节点的权值设为0,并向森林中每一棵树的根节点连边
问题就转化为了 : 有一棵树 , 你需要给这棵树染色 , 除根节点外 , 染了父亲节点才能染子节点 , 每次染色花费的代价为 : i * Wi , 其中i代表的是第几次染色 , 要求最大化染色的代价
我们可以发现一个性质 : 权值最小的节点一定在其父节点被染色后立即染色
根据这个性质 , 我们不妨找出权值最小的节点 , 将其与父节点合并 , 合并后的权值为它们权值和的平均值
重复以上过程N次 , 就得到了答案
具体实现时 , 我们需要一个堆维护权值最小的节点 , 此外 , 还需用并查集快速找出某个点合并后的父亲
时间复杂度 : O(NlogN)
[代码]
#include<bits/stdc++.h> using namespace std; #define MAXN 500010 typedef long long ll; int n,size; int fa[MAXN],deg[MAXN],hp[MAXN],pos[MAXN],s[MAXN]; ll ans; vector< int > e[MAXN],ver[MAXN]; pair<ll,int> a[MAXN]; template <typename T> inline void read(T &x) { ll f = 1; x = 0; char c = getchar(); for (; !isdigit(c); c = getchar()) if (c == '-') f = -f; for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0'; x *= f; } inline bool cmp(int x,int y) { return a[x].first * a[y].second < a[x].second * a[y].first; } inline int get_root(int x) { if (s[x] == x) return x; return s[x] = get_root(s[x]); } inline bool topsort() { int l,r; static int q[MAXN]; l = 1; r = 0; for (register int i = 1; i <= n; i++) { if (!deg[i]) q[++r] = i; } while (l <= r) { int u = q[l]; l++; for (register int i = 0; i < (int)e[u].size(); i++) { int v = e[u][i]; if ((--deg[v]) == 0) q[++r] = v; } } for (register int i = 1; i <= n; i++) { if (deg[i]) return false; } return true; } inline void up(int x) { int fa = x >> 1; if (x == 1) return; if (cmp(hp[x],hp[fa])) { swap(pos[hp[fa]],pos[hp[x]]); swap(hp[x],hp[fa]); up(fa); } } inline void down(int x) { int son = x << 1; if (son > size) return; if (son + 1 <= size && cmp(hp[son + 1],hp[son])) son++; if (cmp(hp[son],hp[x])) { swap(pos[hp[son]],pos[hp[x]]); swap(hp[son],hp[x]); down(son); } } inline void push(int a) { hp[++size] = a; pos[a] = size; up(size); } inline void pop() { swap(pos[hp[1]],pos[hp[size]]); swap(hp[1],hp[size--]); down(1); } inline int top() { return hp[1]; } inline bool is_empty() { return size == 0; } int main() { read(n); for (register int i = 1; i <= n; i++) { int x; read(x); ver[x].push_back(i); } for (register int i = 1; i <= n; i++) { ll x; read(x); a[i] = make_pair(x,1); } for (register int i = 1; i <= n; i++) { for (register int j = 0; j < (int)ver[i].size(); j++) { e[i].push_back(ver[i][j]); fa[ver[i][j]] = i; deg[ver[i][j]]++; } } for (register int i = 1; i <= n; i++) { if (!fa[i]) { e[n + 1].push_back(i); deg[i]++; fa[i] = n + 1; } } a[++n] = make_pair(0,1); for (register int i = 1; i <= n; i++) s[i] = i; if (!topsort()) { printf("-1 "); return 0; } for (register int i = 1; i < n; i++) push(i); for (register int i = 1; i < n; i++) { int v = top() , f = get_root(fa[v]); pop(); ans += a[v].first * a[f].second; a[f].first += a[v].first; a[f].second += a[v].second; s[v] = f; if (f != n) up(pos[f]); } printf("%lld ",ans); return 0; }