题意
给出一棵(n)个点的树,一个点会等概率地向相邻的点走去,花费(1)的时间,现在给你若干个点(v_0,v_1,...,v_p),求(v_0)->(v_1)的期望时间+(v_1)->(v_2)的期望时间+(v_{p-1})->(v_p)的期望时间。
(nleq 50000),询问总点数(leq 50000)。
Solution
设(d_i)表示(i)的度数,(f_i)表示(i)走向父亲的期望步数,转移两种情况:
- 直接走向父亲
- 走到某个儿子再走回来再走向父亲
得出这样的转移方程:
(f_i=frac{1}{d_i}+frac{1}{d_i}sum_{vin son_i}1+f_v+f_i)
经过简单的化简就能得到:
(f_i=d_i+sum_{vin son_i}f_v)
那么一个dfs就求出了(f)。
设(g_i)表示(i)从父亲走来的期望步数,转移三种情况:
- 直接由父亲走来
- 父亲走到另一个儿子再走回来再走到(i)
- 父亲走到它的父亲再走回来再走到(i)
得出这样的转移方程:
(g_i=frac{1}{d_{fa_i}}+frac{1}{d_{fa_i}}sum_{vin son_{fa_i},v
e i}(1+f_v+g_i)+frac{1}{d_{fa_i}}(1+g_{fa_i}+g_i))
经过简单的化简就能得到:
(g_i=d_{fa_i}+g_{fa_i}+sum_{vin son_{fa_i},v
e i}f_v)
那么一个dfs就求出了(g)。
在树上记录一下(f)和(g)的前缀和(根到某个点的(f),(g)之和),然后加个倍增求lca,这题就顺利解决了。
Code
#include <cstdio>
#include <cstring>
typedef long long ll;
const int N = 50007;
void swap(int &a, int &b) { int t = a; a = b, b = t; }
int T;
int q, p, v[N];
int n, tot, st[N], to[N << 1], nx[N << 1], d[N], dep[N], anc[N][17];
ll ans, f[N], g[N], sumf[N], sumg[N];
void add(int u, int v) { to[++tot] = v, nx[tot] = st[u], st[u] = tot; }
void getf(int u, int from)
{
ll sum = 0;
for (int i = st[u]; i; i = nx[i]) if (to[i] != from) getf(to[i], u), sum += f[to[i]];
f[u] = d[u] + sum;
}
void getg(int u, int from)
{
ll sum = 0;
for (int i = st[u]; i; i = nx[i]) if (to[i] != from) sum += f[to[i]];
for (int i = st[u]; i; i = nx[i]) if (to[i] != from) g[to[i]] = g[u] + d[u] + sum - f[to[i]], getg(to[i], u);
}
void dfs(int u, int from)
{
sumf[u] = sumf[from] + f[u], sumg[u] = sumg[from] + g[u];
for (int i = st[u]; i; i = nx[i]) if (to[i] != from) dep[to[i]] = dep[u] + 1, anc[to[i]][0] = u, dfs(to[i], u);
}
int getlca(int u, int v)
{
if (dep[u] < dep[v]) swap(u, v);
for (int i = 16; i >= 0; i--) if (dep[anc[u][i]] >= dep[v]) u = anc[u][i];
if (u == v) return u;
for (int i = 16; i >= 0; i--) if (anc[u][i] != anc[v][i]) u = anc[u][i], v = anc[v][i];
return anc[u][0];
}
ll getlen(int u, int v)
{
int lca = getlca(u, v);
ll ret = sumf[u] - sumf[lca] + sumg[v] - sumg[lca];
return ret;
}
int main()
{
scanf("%d", &T);
while (T--)
{
tot = 0;
memset(d, 0, sizeof(d)); memset(f, 0, sizeof(f)); memset(g, 0, sizeof(g));
memset(anc, 0, sizeof(anc)); memset(dep, 0, sizeof(dep)); memset(st, 0, sizeof(st)); memset(nx, 0, sizeof(nx));
scanf("%d", &n);
for (int i = 1, u, v; i < n; i++) scanf("%d%d", &u, &v), u++, v++, add(u, v), add(v, u), d[u]++, d[v]++;
getf(1, 0), getg(1, 0), dep[1] = 1, dfs(1, 0);
for (int j = 1; j <= 16; j++) for (int i = 1; i <= n; i++) anc[i][j] = anc[anc[i][j - 1]][j - 1];
scanf("%d", &q);
while (q--)
{
ans = 0;
scanf("%d", &p);
for (int i = 0; i <= p; i++) scanf("%d", &v[i]), v[i]++;
for (int i = 1; i <= p; i++) ans += getlen(v[i - 1], v[i]);
printf("%lld.0000
", ans);
}
printf("
");
}
return 0;
}