测试地址:排列
做法:本题需要用到贪心+堆+并查集。
观察题目中的条件,实际上就是指不能出现在之前,也就是说,我们可以把问题看成选点,如果点是第个选的,会获得的收益,并且不能比先取,求最大的收益。
如果没有限制条件,根据排序不等式,先取小的肯定是最优的,可是考虑限制后怎么办呢?
首先,如果约束条件形成了一个环,显然不可能有合法的排列,否则约束条件就形成以为根的树,其中为的父亲。我们有以下结论:
考虑当前权值最小的点,那么点的父亲一旦被选,接着必须马上选点才能最优。
至于证明……可以去看网上各位大佬的证明,我只会感性理解……
以上结论表示,我们可以把这样的和绑在一起,合并成一个新的节点。那么新的节点的权值应该怎么确定呢?考虑两个序列拼起来能得到的新的贡献,令第一个序列长为,元素和为,第二个序列长为,元素和为,那么把第一个序列拼到第二个序列后面可以得到的收益,把第二个序列拼到第一个序列后面可以得到的收益,那么使得先取第一个序列更优的条件为,也就是。那么我们就可以用这样的平均权值来作为新节点的权值,那么只要用堆和并查集维护这个贪心即可。
要注意的是,因为这里的堆要支持定点删除,所以STL中的priority_queue可能不能用,而用set貌似会被卡常,所以最好的方法是用手写堆(毕竟考试好像不能用pbds……),时间复杂度为。
(话说这题应该是原题,原题题号是POJ2054,在省选前校内胡测就见过,可惜不会,没想到居然真的有省选会拿原题来出啊……)
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,a[500010],hp[500010],hpsiz=0,pos[500010];
int fa[500010],vis[500010]={0};
ll w[500010],sum[500010],siz[500010],ans=0;
bool cmp(int a,int b)
{
return sum[a]*siz[b]<sum[b]*siz[a];
}
bool check()
{
int tim=0,x;
for(int i=1;i<=n;i++)
if (!vis[i])
{
vis[i]=++tim;
x=i;
while(!vis[a[x]])
{
vis[a[x]]=tim;
x=a[x];
}
if (a[x]&&vis[a[x]]==tim) return 1;
}
return 0;
}
void maintaindown(int v)
{
int mn=v;
if ((v<<1)<=hpsiz&&cmp(hp[v<<1],hp[mn]))
mn=v<<1;
if ((v<<1|1)<=hpsiz&&cmp(hp[v<<1|1],hp[mn]))
mn=v<<1|1;
if (mn==v) return;
swap(hp[v],hp[mn]);
pos[hp[v]]=v;
pos[hp[mn]]=mn;
maintaindown(mn);
}
void maintainup(int v)
{
if (v==1) return;
if (cmp(hp[v],hp[v>>1]))
{
swap(hp[v>>1],hp[v]);
pos[hp[v>>1]]=v>>1;
pos[hp[v]]=v;
maintainup(v>>1);
}
}
void insert(int x)
{
hp[++hpsiz]=x;
pos[x]=hpsiz;
maintainup(hpsiz);
}
void Delete(int x)
{
int v=hp[hpsiz--];
pos[v]=x;
if (cmp(hp[x],v)) hp[x]=v,maintaindown(x);
else if (cmp(v,hp[x])) hp[x]=v,maintainup(x);
else hp[x]=v;
}
int find(int x)
{
int r=x,i=x,j;
while(r!=fa[r]) r=fa[r];
while(i!=r) j=fa[i],fa[i]=r,i=j;
return r;
}
void merge(int x,int y)
{
int fx=find(x),fy=find(y);
fa[fx]=fy;
sum[fy]+=sum[fx];
siz[fy]+=siz[fx];
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<=n;i++)
scanf("%lld",&w[i]);
if (check()) {printf("-1");return 0;}
fa[0]=sum[0]=0;
siz[0]=1;
for(int i=1;i<=n;i++)
{
fa[i]=i;
sum[i]=w[i];
siz[i]=1;
insert(i);
}
bool flag=0;
while(hpsiz>=1)
{
int v=hp[1];Delete(1);
int f=find(a[v]);
if (f) Delete(pos[f]);
ans+=sum[v]*siz[f];
merge(v,f);
if (f) insert(f);
}
printf("%lld",ans);
return 0;
}