题面
题解
考虑非根结点 (u) 和它的父亲 (fa_u) ,若 (fa_u) 除了 (u) 所在的子树的其他子树中有其他黑点 (v) ,那么称 (fa_u) 被 (v) 覆盖,并且 (fa_u) 会对以 (u) 为根的子树中的黑色节点产生贡献。对于每一个结点 (u) ,找出最早覆盖 (fa_u) 的结点 (v) ,统计出 (n-1) 个点对 ((u,v)) 。
从小到大加入黑点,查找当前加入的结点 (v) 覆盖了哪些结点,以及对应的点对 ((u,v)) ,对 (u) 的子树贡献加一,子树加用树状数组维护即可。处理完这些点对后,在树状数组上查找 (f_v) ,计入答案中。因为还要知道 (v) 对之前几个黑点产生了贡献,所以还要用一个树状数组统计子树内黑点个数。
( ext{Code}:)
#include <cctype>
#include <cstring>
#include <cstdio>
#include <algorithm>
#define INF 0x3f3f3f3f
using namespace std;
typedef long long lxl;
const int maxn=8e5+5;
const lxl mod=1e9+7;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
char buf[1<<21],*p1=buf,*p2=buf;
template <typename T>
inline void read(T &x)
{
x=0;T f=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
x*=f;
}
struct edge
{
int u,v,next;
edge(int u,int v,int next):
u(u),v(v),next(next){}
edge(){}
}e[maxn];
int head[maxn],ecnt;
inline void add(int u,int v)
{
e[ecnt]=edge(u,v,head[u]);
head[u]=ecnt++;
}
int n,fa[maxn];
int dfn[maxn],idx[maxn],siz[maxn],dfs_cnt;
int Min[maxn],sMin[maxn],son[maxn];
pair<int,int> pii[maxn];
int pcnt;
void dfs(int u)
{
dfn[u]=++dfs_cnt;
idx[dfs_cnt]=u;
siz[u]=1;
Min[u]=u;
for(int i=head[u];~i;i=e[i].next)
{
int v=e[i].v;
dfs(v);
siz[u]+=siz[v];
if(Min[v]<Min[u]) sMin[u]=Min[u],Min[u]=Min[v],son[u]=v;
else if(Min[v]<sMin[u]) sMin[u]=Min[v];
}
}
namespace BIT
{
int sum[maxn];
inline int lowbit(int x) {return x&-x;}
inline void add(int x,int d)
{
for(int i=x;i<=n;i+=lowbit(i))
sum[i]+=d;
}
inline int query(int x)
{
int res=0;
for(int i=x;i>=1;i-=lowbit(i))
res+=sum[i];
return res;
}
inline int query(int l,int r)
{
return query(r)-query(l-1);
}
}
namespace Segment_Tree
{
int sum[maxn];
inline int lowbit(int x) {return x&-x;}
inline void add(int x,int d)
{
for(int i=x;i<=n;i+=lowbit(i))
sum[i]+=d;
}
inline int query(int x)
{
int res=0;
for(int i=x;i>=1;i-=lowbit(i))
res+=sum[i];
return res;
}
inline void modify(int l,int r,int d)
{
add(l,d);add(r+1,-d);
}
}
int main()
{
freopen("dierti.in","r",stdin);
freopen("dierti.out","w",stdout);
read(n);
memset(head,-1,sizeof(head));
int rt;
for(int i=1;i<=n;++i)
{
read(fa[i]);
if(fa[i]) add(fa[i],i);
else rt=i;
}
dfs(rt);
for(int i=1;i<=n;++i) if(fa[i])
pii[++pcnt]=make_pair(son[fa[i]]==i?sMin[fa[i]]:Min[fa[i]],i);
sort(pii+1,pii+pcnt+1);
int ans=1,sum=0;
for(int i=1,p=1;i<=n;++i)
{
while(p<=pcnt&&pii[p].first<=i)
{
int u=pii[p].second;
Segment_Tree::modify(dfn[u],dfn[u]+siz[u]-1,1);
sum+=BIT::query(dfn[u],dfn[u]+siz[u]-1);
if(sum>=mod) sum-=mod;
++p;
}
BIT::add(dfn[i],1);
sum+=Segment_Tree::query(dfn[i])+1;
if(sum>=mod) sum-=mod;
ans=1ll*ans*sum%mod;
}
printf("%d
",ans);
return 0;
}