【LG4437】[HNOI/AHOI2018]排列
题面
题解
题面里这个毒瘤的东西我们转化一下:
对于(forall k,j),若(p_k=a_{p_j}),则(k<j)。
也就是说若(y=a_x),则(y)排在(x)前面,
那么我们在原数组编号中(a_x)向(x)连边可以表示出这种拓扑关系。
那么我们连玩边后肯定是以(0)为根的一颗有根树,否则一定会形成一个环,无解。
贪心地想一下,对于权值最小的点,我们肯定让它尽量往前选,那么在它父亲选完后,我们一定会选它,所以我们可以考虑把它的权值并到它父亲上。
这样子的话,我们每个点就变成了一个序列,
考虑两个序列(a,b)的合并方式决定最优答案(当前已经到了第(i)位):
[W_{ab}=sum_{j=1}^{m_1}(i+j)w_{a_j}+sum_{j=1}^{m_2}(i+j+m_1)w_{b_j}\
W_{ba}=sum_{j=1}^{m_2}(i+j)w_{b_j}+sum_{j=1}^{m_1}(i+j+m_2)w_{a_j}\
W_{ab}-W_{ba}=m_1W_b-m_2W_a
]
那么如果(W_{ab}>W_{ba})则(frac{W_a}{m_1}<frac{W_b}{m_2}),也就是平均数小的放前面。
具体实现详见代码。
代码
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
inline int gi() {
register int data = 0, w = 1;
register char ch = 0;
while (!isdigit(ch) && ch != '-') ch = getchar();
if (ch == '-') w = -1, ch = getchar();
while (isdigit(ch)) data = 10 * data + ch - '0', ch = getchar();
return w * data;
}
const int MAX_N = 5e5 + 5;
struct Graph { int next, to; } e[MAX_N << 1]; int fir[MAX_N], e_cnt;
void clearGraph() { memset(fir, -1, sizeof(fir)); e_cnt = 0; }
void Add_Edge(int u, int v) { e[e_cnt] = (Graph){fir[u], v}; fir[u] = e_cnt++; }
bool vis[MAX_N];
int N, tot, pa[MAX_N], fa[MAX_N], size[MAX_N];
long long w[MAX_N];
void dfs(int x) {
vis[x] = 1, ++tot;
for (int i = fir[x]; ~i; i = e[i].next) {
int v = e[i].to;
if (vis[v]) { puts("-1"); exit(0); }
else dfs(v);
}
}
int getf(int x) { return pa[x] == x ? x : pa[x] = getf(pa[x]); }
struct Node { int u, sz; long long w; } ;
bool operator < (const Node &l, const Node &r) { return l.w * r.sz > r.w * l.sz; }
struct Heap{
Node h[MAX_N]; int cur;
Node top() { return h[1]; }
void push(const Node &x) { h[++cur] = x; push_heap(&h[1], &h[cur + 1]); }
void pop() { pop_heap(&h[1], &h[cur + 1]); --cur; }
bool empty() { return cur == 0; }
} que;
int main () {
#ifndef ONLINE_JUDGE
freopen("cpp.in", "r", stdin);
#endif
clearGraph();
N = gi();
for (int i = 1; i <= N; i++) fa[i] = gi(), Add_Edge(fa[i], i);
for (int i = 1; i <= N; i++) w[i] = gi();
dfs(0); if (tot <= N) return puts("-1") & 0;
for (int i = 0; i <= N; i++) pa[i] = i, size[i] = 1;
for (int i = 1; i <= N; i++) que.push((Node){i, 1, w[i]});
long long ans = 0;
while (!que.empty()) {
Node p = que.top(); que.pop();
int u = getf(p.u);
if (size[u] != p.sz) continue;
int f = getf(fa[u]); pa[u] = f;
ans += w[u] * size[f], w[f] += w[u], size[f] += size[u];
if (f) que.push((Node){f, size[f], w[f]});
}
printf("%lld
", ans);
return 0;
}