好强的题。
方案不好算,改成算概率,注意因为是模意义下的概率所以直接乘法逆元就好不要傻傻地开double。
设$f[i][d][0]$为第i个节点离d层的球球走到第i个点时第i个点没有球的概率, $f[i][d][1]$为有1个球的概率, $f[i][d][2]$为有2个球及以上的概率。
我们可以把$f[i]$看成一个队列, 然后从儿子转移的时候, 就是把儿子的队列一个一个合并起来,最后在队列头加上一个$f[i][0]$, 并且把队列里的所有$f[i][0$~$d][2]$加上$f[i][0$~$d][0]$,并且$f[i][0$~$d][2]$变成0就好了。
合并的时候转移为:
$f[i][d][0]=f[i][d][0]*f[j][d][0]$
$f[i][d][1]=f[i][d][1]*f[j][d][0]+f[i][d][0]*f[j][d][1]$
$f[i][d][2]=f[i][d][0]*f[j][d][2]+f[i][d][1]*f[j][d][2]+f[i][d][1]*f[j][d][1]+f[i][d][2]*f[j][d][2]+f[i][d][2]*f[j][d][1]+f[i][d][2]*f[j][d][0]$
复杂度为O(N),因为每层元素只加1,交集最多为N。

#include<iostream> #include<cstring> #include<cstdlib> #include<cstdio> #include<vector> #define ll long long #define MOD(x) ((x)>=mod?(x)-mod:(x)) using namespace std; const int maxn=500010, mod=1e9+7; struct tjm{int too, pre;}e[maxn<<1]; struct poi{int f[3];}; int n, x, ans, tot, tott; int last[maxn], root[maxn]; vector<poi>q[maxn]; inline void read(int &k) { int f=1; k=0; char c=getchar(); while(c<'0' || c>'9') c=='-' && (f=-1), c=getchar(); while(c<='9' && c>='0') k=k*10+c-'0', c=getchar(); k*=f; } inline void add(int x, int y){e[++tot]=(tjm){y, last[x]}; last[x]=tot;} inline int merge(int x, int y) { if(q[x].size()<q[y].size()) swap(x, y); int nx=q[x].size()-1, ny=q[y].size()-1; for(int i=0;i<=ny;i++) { int sum0=0, sum1=0, sum2=0; sum0=1ll*q[x][nx-i].f[0]*q[y][ny-i].f[0]%mod; sum1=(1ll*q[x][nx-i].f[1]*q[y][ny-i].f[0]+1ll*q[x][nx-i].f[0]*q[y][ny-i].f[1])%mod; for(int j=0;j<3;j++) for(int k=2;j+k>=2;k--) sum2=(1ll*sum2+1ll*q[x][nx-i].f[j]*q[y][ny-i].f[k])%mod; q[x][nx-i].f[0]=sum0; q[x][nx-i].f[1]=sum1; q[x][nx-i].f[2]=sum2; } q[y].clear(); return x; } void dfs(int x, int fa) { if(!last[x]) root[x]=++tott; int dep=0; for(int i=last[x], too;i;i=e[i].pre) if((too=e[i].too)!=fa) { dfs(too, x); if(!root[x]) root[x]=root[too]; else dep=max(dep, (int)min(q[root[x]].size(), q[root[too]].size())), root[x]=merge(root[x], root[too]); } int nx=q[root[x]].size()-1; for(int i=0;i<dep;i++) q[root[x]][nx-i].f[0]=MOD(q[root[x]][nx-i].f[0]+q[root[x]][nx-i].f[2]), q[root[x]][nx-i].f[2]=0; poi tmp; tmp.f[1]=tmp.f[0]=(mod+1)>>1; tmp.f[2]=0; q[root[x]].push_back(tmp); } inline int power(int a, int b) { int ans=1; for(;b;b>>=1, a=1ll*a*a%mod) if(b&1) ans=1ll*ans*a%mod; return ans; } int main() { read(n); for(int i=1;i<=n;i++) read(x), add(x, i); dfs(0, -1); for(int i=0;i<q[root[0]].size();i++) ans=MOD(ans+q[root[0]][i].f[1]); printf("%lld ", 1ll*ans*power(2, n+1)%mod); }