【题目链接】
https://www.lydsy.com/JudgeOnline/problem.php?id=5289
https://www.luogu.org/problemnew/show/P4437
【题解】
限制条件可以归纳为:若要在之前被选择。
那么我们将限制关系连边,无解的条件当且仅当出现环。
否则一定是棵树(因为每个点入度都为1)。
现在问题转换如下:有一棵树,每个点有权值,现在要再给每个点分配一个不重复的权值,满足在此基础上最大化。
先来考虑一个简单的情况:若且是最小的儿子。那么选了之后下一个一定会选。那么我们可以把与合并。
这个结论对于一个连通块也是对的,所以块与块之间也可以比较大小,即比较的平均值。所以我们可以每次找出权值最小的联通快,将其与他的父亲合并。用并查集+堆即可实现。
时间复杂度:
# include <bits/stdc++.h>
# define N 500010
# define ll long long
using namespace std;
int read(){
int tmp=0, fh=1; char ch=getchar();
while (ch<'0'||ch>'9'){if (ch=='-') fh=-1; ch=getchar();}
while (ch>='0'&&ch<='9'){tmp=tmp*10+ch-'0'; ch=getchar();}
return tmp*fh;
}
vector <int> p[N];
struct node{
int data,next;
}e[N*2];
int cnt[N],low[N],dfn[N],head[N],place,use[N],ti,a[N],w[N],n,fa[N],f[N],hp[N],size,id[N];
ll sum[N],num[N],ans[N];
bool flag;
void build(int u, int v){
e[++place].data=v; e[place].next=head[u];
head[u]=place; cnt[v]++;
}
void tarjan(int x){
use[x]=1; low[x]=dfn[x]=++ti;
for (int ed=head[x]; ed!=0; ed=e[ed].next){
if (use[e[ed].data]==2) continue;
if (use[e[ed].data]==1){
flag=false; return;
low[x]=min(low[x],dfn[e[ed].data]);
}
else{
tarjan(e[ed].data);
low[x]=min(low[x],low[e[ed].data]);
if (flag==false) return;
}
}
}
void dfs(int x){
for (int ed=head[x]; ed!=0; ed=e[ed].next){
dfs(e[ed].data);
fa[e[ed].data]=x;
}
}
int dad(int x){
if (f[x]==x) return x;
else return f[x]=dad(f[x]);
}
bool cmp(int x, int y){
return sum[x]*1.0/num[x]<sum[y]*1.0/num[y];
}
void changeup(int x){
while (x!=1){
if (cmp(hp[x],hp[x/2])==true){
swap(id[hp[x]],id[hp[x/2]]);
swap(hp[x],hp[x/2]);
x=x/2;
}
else return;
}
}
void changedown(int x){
int mn;
while (x*2<=size){
if (x*2+1<=size&&cmp(hp[x*2+1],hp[x*2])==true)
mn=x*2+1; else mn=x*2;
if (cmp(hp[mn],hp[x])==true){
swap(id[hp[mn]],id[hp[x]]);
swap(hp[mn],hp[x]);
x=mn;
}
else return;
}
}
int main(){
n=read();
for (int i=1; i<=n; i++){
a[i]=read();
p[a[i]].push_back(i);
}
for (int i=1; i<=n; i++) w[i]=read();
for (int i=1; i<=n; i++)
for (unsigned j=0; j<p[i].size(); j++)
build(i,p[i][j]);
for (int i=1; i<=n; i++)
if (cnt[i]==0) build(0,i);
flag=true;
tarjan(0);
if (flag==false||ti!=n+1){
printf("-1
");
return 0;
}
memset(use,0,sizeof(use));
dfs(0);
for (int i=1; i<=n; i++){
num[i]=1, sum[i]=w[i];
ans[i]=0;
f[i]=i; id[i]=i;
hp[++size]=i;
changeup(i);
}
num[0]=1;
for (int i=1; i<=n; i++){
int now=hp[1], an=dad(fa[now]);
swap(id[hp[1]],id[hp[size]]);
swap(hp[1],hp[size--]);
changedown(1);
ans[now]=ans[now]+sum[now]*num[an];
ans[an]=ans[an]+ans[now];
num[an]=num[an]+num[now];
sum[an]=sum[an]+sum[now];
f[now]=an;
if (an!=0) changeup(id[an]);
}
printf("%lld
",ans[0]);
return 0;
}