题目描述
题解
这题考场时想了特别久,花了很多时间,但是只想出了(O(n^2))的做法,满分做法其实不难。
容易发现,一个点如果能变成黑的,当且仅当这个点是黑色或者子树中有两个节点是黑色的。
进一步可以发现,对于这个子树中的点,他们对于这个点的要求是:除自己外的子树有黑色点,或者这个点是黑色。
也就是只要这个点能变成黑色,就能对其子树中所有点产生贡献。
不妨用一个树形dp,将每个点变成黑色的最短时间计算出来。然后按编号从小到大加入点,如果在当前编号下某个点可以变成黑色,就将这个点子树中的点(出这个点以外)对答案的贡献加一,同时将当前最新变色的点对答案的贡献加一。然后就能知道每个点首次加进去的答案了。
但是还需要统计对之前的点的贡献,因为这些点也可能被当前的区间修改,所以再开一个线段树维护一下就行了。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 800010
#define ll long long
#define mo 1000000007
using namespace std;
int n,i,j,bfs[N],fa[N],dfn[N],son[N],num,root,q[N],tot,f[N],g[N],stk[N][2],top;
int x,y;
ll sum,ans;
struct node{
ll lazy,sum,pk;
}tr[N*5];
struct edge{
int to,next;
}e[N];
struct pl{
int time,val;
}dp[N];
int read(){
int x=0;
char ch=getchar();
while (ch<'0'||ch>'9') ch=getchar();
while (ch>='0'&&ch<='9'){
x=(x<<1)+(x<<3)+(ch^48);
ch=getchar();
}
return x;
}
void insert_edge(int x,int y){
tot++;
e[tot].to=y;
e[tot].next=q[x];
q[x]=tot;
}
int cmp(pl x,pl y){
return x.time<y.time;
}
void update(int x,int l,int r){
ll p=tr[x].lazy;
if (l<r){
int mid=(l+r)/2;
tr[x*2].sum=(tr[x*2].sum+(mid-l+1)*p)%mo;
tr[x*2+1].sum=(tr[x*2+1].sum+(r-mid)*p)%mo;
tr[x*2].lazy=(tr[x*2].lazy+p)%mo;tr[x*2+1].lazy=(tr[x*2+1].lazy+p)%mo;
}
tr[x].lazy=0;
}
void change(int x,int l,int r,int l1,int r1){
if (l1>r1) return;
update(x,l,r);
if (l1<=l&&r1>=r){
tr[x].sum=(tr[x].sum+r-l+1)%mo;
tr[x].lazy=tr[x].lazy+1;
return;
}
int mid=(l+r)/2;
if (l1<=mid) change(x*2,l,mid,l1,r1);
if (r1>mid) change(x*2+1,mid+1,r,l1,r1);
tr[x].sum=(tr[x*2].sum+tr[x*2+1].sum)%mo;
}
void changepk(int x,int l,int r,int k){
if (l==r){
tr[x].pk++;
return;
}
int mid=(l+r)/2;
if (k<=mid) changepk(x*2,l,mid,k);
else changepk(x*2+1,mid+1,r,k);
tr[x].pk=(tr[x*2].pk+tr[x*2+1].pk)%mo;
}
ll gets(int x,int l,int r,int k){
update(x,l,r);
if (l==r) return tr[x].sum;
int mid=(l+r)/2;
if (k<=mid) return gets(x*2,l,mid,k);
else return gets(x*2+1,mid+1,r,k);
}
ll getpk(int x,int l,int r,int l1,int r1){
if (l1>r1) return 0;
if (l1<=l&&r1>=r) return tr[x].pk;
int mid=(l+r)/2,pk=0;
if (l1<=mid) pk=(pk+getpk(x*2,l,mid,l1,r1))%mo;
if (r1>mid) pk=(pk+getpk(x*2+1,mid+1,r,l1,r1))%mo;
return pk;
}
int main(){
freopen("dierti.in","r",stdin);
freopen("dierti.out","w",stdout);
n=read();
for (i=1;i<=n;i++){
fa[i]=read();
if (fa[i]==0) root=i;
else insert_edge(fa[i],i);
}
top=1;stk[1][0]=root;stk[1][1]=q[root];
while (top){
x=stk[top][0];
if (stk[top][1]==q[x]){
f[x]=g[x]=1e9;
dfn[x]=++num;
}
if (!stk[top][1]){
g[fa[x]]=min(g[fa[x]],min(f[x],x));
if (g[fa[x]]<f[fa[x]]) swap(f[fa[x]],g[fa[x]]);
son[x]=num;
top--;
continue;
}
for (i=stk[top][1];i;i=e[i].next){
y=e[i].to;
stk[top][1]=e[i].next;
stk[++top][0]=y;stk[top][1]=q[y];
break;
}
}
for (i=1;i<=n;i++) dp[i].time=min(i,g[i]),dp[i].val=i;
sort(dp+1,dp+n+1,cmp);
j=0;
ans=1;
for (i=1;i<=n;i++){
change(1,1,n,dfn[i],dfn[i]);
while (j+1<=n&&dp[j+1].time<=i){
j++;
change(1,1,n,dfn[dp[j].val]+1,son[dp[j].val]);
sum=(sum+getpk(1,1,n,dfn[dp[j].val]+1,son[dp[j].val]))%mo;
}
sum=(sum+gets(1,1,n,dfn[i]))%mo;
changepk(1,1,n,dfn[i]);
ans=ans*sum%mo;
}
printf("%lld
",ans);
return 0;
}