Random Access Iterator
[Time Limit: 4000 ms quad Memory Limit: 262144 kB
]
题意
给出伪代码,问按着伪代码在树上跑,能够正确求出来树的深度的概率。
思路
先在树上 (dfs) 一遍,求出每个点可以走到的最深深度,用 (deep) 表示。
令 (dp[i]) 表示从 (i) 节点能够正确求出最深深度的概率。
- 如果 (deep[u]) 不等于最深深度,那么明显 (dp[u]=0)
- 如果 (deep[u]) 是最深深度而且走到了叶子节点,那么 (dp[u]=1)
- 其他情况,能够走到最深深度的概率 = (1) - 每次走每个点都不能走到最深深度的概率,用表达式表达就是,(sz) 表示当前节点可以走的其他节点数,特别注意 (u=1) 的情况就可以了。
[ dp[u] = 1-left(frac{sum_{v}(1-dp[v)}{sz-1}
ight)^{sz-1}
]
/***************************************************************
> File Name : a.cpp
> Author : Jiaaaaaaaqi
> Created Time : Wed 11 Sep 2019 02:48:22 PM CST
***************************************************************/
#include <map>
#include <set>
#include <list>
#include <ctime>
#include <cmath>
#include <stack>
#include <queue>
#include <cfloat>
#include <string>
#include <vector>
#include <cstdio>
#include <bitset>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <unordered_map>
#define lowbit(x) x & (-x)
#define mes(a, b) memset(a, b, sizeof a)
#define fi first
#define se second
#define pb push_back
#define pii pair<int, int>
typedef unsigned long long int ull;
typedef long long int ll;
const int maxn = 1e6 + 10;
const int maxm = 1e5 + 10;
const ll mod = 1e9 + 7;
const ll INF = 1e18 + 100;
const int inf = 0x3f3f3f3f;
const double pi = acos(-1.0);
const double eps = 1e-8;
using namespace std;
int n, m;
int cas, tol, T;
int maxdeep;
vector<int> vv[maxn];
int deep[maxn];
ll dp[maxn];
ll fpow(ll a, ll b) {
ll ans = 1;
while(b) {
if(b&1) ans = ans*a%mod;
a = a*a%mod;
b >>= 1;
}
return ans;
}
void dfs(int u, int fa, int d) {
deep[u] = d;
for(auto v : vv[u]) {
if(v == fa) continue;
dfs(v, u, d+1);
deep[u] = max(deep[u], deep[v]);
}
maxdeep = max(maxdeep, deep[u]);
}
void solve(int u, int fa) {
if(deep[u] != maxdeep) {
dp[u] = 0;
return ;
}
if(vv[u].size() == 1) {
dp[u] = 1;
return ;
}
int sz = vv[u].size()-1;
if(u == 1) sz++;
ll ans = 0;
for(auto v : vv[u]) {
if(v == fa) continue;
solve(v, u);
ans += (1-dp[v]+mod);
ans %= mod;
}
ans = ans*fpow(sz, mod-2)%mod;
ans = fpow(ans, sz);
dp[u] = (1-ans+mod)%mod;
}
int main() {
// freopen("in", "r", stdin);
scanf("%d", &n);
for(int i=1; i<=n; i++) {
vv[i].clear();
}
for(int i=1; i<n; i++) {
int u, v;
scanf("%d%d", &u, &v);
vv[u].pb(v);
vv[v].pb(u);
}
maxdeep = 0;
dfs(1, 0, 1);
solve(1, 0);
printf("%lld
", dp[1]);
return 0;
}