Description
给定 (n) 个整数 (a_1, a_2, dots, a_n, 0 le a_i le n) ,以及 (n) 个整数 (w_1, w_2, dots, w_n) 。称 (a_1, a_2, dots, a_n) 的 一个排列 (a_{p[1]}, a_{p[2]}, dots, a_{p[n]}) 为 (a_1, a_2, dots, a_n) 的一个合法排列,当且仅当该排列满足:对于任意的 (k) 和任意的 (j) ,如果 (p[k]) 等于 (a_{p[j]}) ,那么 (k<j) 。定义这个合法排列的权值为 (w_{p[1]} + 2w_{p[2]} + dots + nw_{p[n]}) 。
求出在所有合法排列中的最大权值。如果不存在合法排列,输出 (-1) 。
(1leq nleq 500000,0leq a_ileq n,1leq w_ileq 10^9) ,(sum w_ileq 1.5 imes 10^{13})
Solution
假如我们对于所有的 (i) , (a[i]) 和 (i) 间建一条边,显然这副图可能构成了一棵树。
如果不存在合法排列,当前仅当构成的图非树。
如何构成了树,那么原题的模型就变成了:给出一棵以 (0) 为根的有根树,需要为非 (0) 顶点标号 (1sim n) ,并且满足父亲比自己先标号。每个节点有点权,树的价值为点权乘标号的和。求树最大的价值。
一个显然的贪心是如果当前树中权值最小的点 (u) 没有父亲,那么我们当前一定是选 (u) 。
不过大部分不是这种情况。
考虑如果 (u) 有父亲,显然当他的父亲被选之后马上就会选 (u) ,也就是说父子间的编号一定是相邻的。我们可以将 (u) 的答案并在他的父亲中。
同样的,对于两个不同的“块”,也是如此。
考虑一个长度为 (l_1) 的序列 (A) 和一个长度为 (l_2) 的序列 (B) ,
序列前面已经安排好了 (loc) 个。考虑 (AB) 和 (BA) 两种合并后的序列的答案:
[W_{AB}=sum_{i=1}^{l_1}(i+loc)w_{A_i}+sum_{i=1}^{l_2}(i+loc+l_1)w_{B_i}]
[W_{BA}=sum_{i=1}^{l_2}(i+loc)w_{B_i}+sum_{i=1}^{l_1}(i+loc+l_2)w_{A_i}]
如果 (W_{AB}> W_{BA}Rightarrow frac{sum_{i=1}^{l_1}w_{A_i}}{l_1}<frac{sum_{i=1}^{l_2}w_{B_i}}{l_2})
也就是平均权值小的放前面答案会更优。
那么我们就可以用堆来维护这个东西。
不知道为什么写了个支持删除的堆只有 50 ,然而不删去而在取出堆顶时判断是否合法就对了...
Code
#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 500000+5;
int n, a[N], fa[N], sz[N]; ll w[N];
struct node {
int id; ll son, mom;
node (int _id = 0, ll _son = 0, ll _mom = 0) {id = _id, son = _son; mom = _mom; }
bool operator < (const node &b) const {return son*b.mom > b.son*mom; }
};
priority_queue<node>Q;
int find(int o) {return ~fa[o] ? fa[o] = find(fa[o]) : o; }
void work() {
memset(fa, -1, sizeof(fa));
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
if (find(a[i])^find(i)) fa[find(a[i])] = find(i);
else {puts("-1"); return; }
}
long long ans = 0; int loc = 0;
for (int i = 1; i <= n; i++) {
scanf("%lld", &w[i]); Q.push(node(i, w[i], 1));
sz[i] = 1; ans += w[i];
}
memset(fa, -1, sizeof(fa));
while (!Q.empty()) {
node t = Q.top(); Q.pop();
if (sz[t.id] != t.mom) continue;
if (find(a[t.id]) == 0) {
ans += w[t.id]*loc; fa[t.id] = 0; loc += sz[t.id];
}else {
int tmp = find(a[t.id]);
ans += w[t.id]*sz[tmp], fa[t.id] = tmp;
w[tmp] += w[t.id], sz[tmp] += sz[t.id];
Q.push(node(tmp, w[tmp], sz[tmp]));
}
}
printf("%lld
", ans);
}
int main() {work(); return 0; }