首先分析操作的实质,其实它相当于把某个点连同子树插入到父亲的位置,并将父亲删除,且其余不变。那么
f
f
f值的暴力求法则可以不断往父亲上跳,当父亲的子树除去自己的子树外还有黑点时
f
+
1
f+1
f+1。
如果暴力维护这个过程,复杂度是
O
(
n
3
)
O(n^3)
O(n3),稍微优化一下可以到
O
(
n
2
)
O(n^2)
O(n2),这都不难想到,但是仍旧无法通过。
依次考虑每个点
i
i
i新加入后的贡献,贡献分两种:
1、
i
i
i到根节点路径上的点对
f
i
f_i
fi的贡献;
2、
i
i
i的出现对其它若干个
f
x
f_x
fx的贡献。
发现可以两棵用线段树来维护,一棵维护出现的点的个数(即每次新加入点则在
d
f
n
i
dfn_i
dfni的位置
+
1
+1
+1),一棵维护当前点的加入对还未出现的点的贡献。这两棵线段树都是以DFS序为下标的。
贡献一则直接查询第二棵线段树上对应位置的答案。
贡献二要先预处理每个点
i
i
i父亲的子树中除去
i
i
i子树外的编号最小值
x
x
x,则该最小值出现后能对
i
i
i的子树中所有点产生贡献
+
1
+1
+1的影响。当加入
x
x
x之后,先在所有它对应的
i
i
i子树中查询已出现的点的个数加入答案, 然后在
i
i
i子树中所有点的贡献
+
1
+1
+1。
代码
#include<cstdio>#include<cstring>#include<algorithm>usingnamespace std;#define N 800010#define ll long long#define md 1000000007int last[N], nxt[N], to[N], len =0;int last1[N], nxt1[N], to1[N], len1 =0;int fa[N], dfn[N], pr[N], si[N];int d[N], g[N *4][2];struct{int s[2], mi;}f[N *4];intread(){int s =0;char x =getchar();while(x <'0'|| x >'9') x =getchar();while(x >='0'&& x <='9') s = s *10+ x -48, x =getchar();return s;}voidadd(int x,int y){
to[++len]= y;
nxt[len]= last[x];
last[x]= len;}voidadd1(int x,int y){
to1[++len1]= y;
nxt1[len1]= last1[x];
last1[x]= len1;}voiddfs(int k){
dfn[k]=++dfn[0];
si[k]=1;for(int i = last[k]; i; i = nxt[i])dfs(to[i]), si[k]+= si[to[i]];}voidmake(int v,int l,int r){if(l == r) f[v].mi = pr[l];else{int mid =(l + r)/2;make(v *2, l, mid),make(v *2+1, mid +1, r);
f[v].mi =min(f[v *2].mi, f[v *2+1].mi);}}voidins(int v,int l,int r,int x,int y,int c){if(l == x && r == y){
f[v].s[c]+= r - l +1;
g[v][c]++;}else{int mid =(l + r)/2;
f[v *2].s[c]+= g[v][c]*(mid - l +1), f[v *2+1].s[c]+= g[v][c]*(r - mid);
g[v *2][c]+= g[v][c], g[v *2+1][c]+= g[v][c];
g[v][c]=0;if(y <= mid)ins(v *2, l, mid, x, y, c);elseif(x > mid)ins(v *2+1, mid +1, r, x, y, c);elseins(v *2, l, mid, x, mid, c),ins(v *2+1, mid +1, r, mid +1, y, c);
f[v].s[c]= f[v *2].s[c]+ f[v *2+1].s[c];}}intfs(int v,int l,int r,int x,int y,int c){if(l == x && r == y)return f[v].s[c];int mid =(l + r)/2;
f[v *2].s[c]+= g[v][c]*(mid - l +1), f[v *2+1].s[c]+= g[v][c]*(r - mid);
g[v *2][c]+= g[v][c], g[v *2+1][c]+= g[v][c];
g[v][c]=0;int ans;if(y <= mid) ans =fs(v *2, l, mid, x, y, c);elseif(x > mid) ans =fs(v *2+1, mid +1, r, x, y, c);else ans =fs(v *2, l, mid, x, mid, c)+fs(v *2+1, mid +1, r, mid +1, y, c);
f[v].s[c]= f[v *2].s[c]+ f[v *2+1].s[c];return ans;}intfi(int v,int l,int r,int x,int y){if(x > y)return N;if(l == x && r == y)return f[v].mi;int mid =(l + r)/2;if(y <= mid)returnfi(v *2, l, mid, x, y);if(x > mid)returnfi(v *2+1, mid +1, r, x, y);returnmin(fi(v *2, l, mid, x, mid),fi(v *2+1, mid +1, r, mid +1, y));}intmain(){int n =read(), i, j, rt;for(i =1; i <= n; i++){
fa[i]=read();if(fa[i])add(fa[i], i);else rt = i;}dfs(rt);for(i =1; i <= n; i++) pr[dfn[i]]= i;make(1,1, n);for(i =1; i <= n; i++)if(i != rt){
d[i]=min(fi(1,1, n, dfn[fa[i]]+1, dfn[i]-1),fi(1,1, n, dfn[i]+ si[i], dfn[fa[i]]+ si[fa[i]]-1));
d[i]=min(d[i], fa[i]);add1(d[i], i);}
ll sum =0, ans =1;for(i =1; i <= n; i++){
sum =(sum +1+fs(1,1, n, dfn[i], dfn[i],1))% md;for(j = last1[i]; j; j = nxt1[j]){int x = to1[j];
sum =(sum +fs(1,1, n, dfn[x], dfn[x]+ si[x]-1,0))% md;ins(1,1, n, dfn[x], dfn[x]+ si[x]-1,1);}ins(1,1, n, dfn[i], dfn[i],0);
ans = ans * sum % md;}printf("%d
", ans);return0;}