题目描述
桑尼、露娜和斯塔在玩点对游戏,这个游戏在一棵节点数为(n)的树上进行。桑尼、露娜和斯塔三人轮流从树上所有未被占有的节点中选取一点,归为己有,轮流顺序为桑尼、露娜、斯塔、桑尼、露娜……。该选取过程直到树上所有点都被选取后结束。
选完点后便可计算每人的得分。点对游戏中有(m)个幸运数,在某人占据的节点中,每有一对点的距离为某个幸运数,就得到一分。(树上两点之间的距离定义为两点之间的简单路径的边数)。
你的任务是,假设桑尼、露娜和斯塔每次选取时,都是从未被占有的节点中等概率选取一点,计算每人的期望得分。
题解
首先可以知道每个人可以选择的点数是固定的,可以很容易求得。
那么考虑,假设当前的人可以选(k)个点,那么他有(C_n^k)种选择方式。
假设一对点的距离满足题目要求,那么他会做出(C_{n-2}^{k-2})次贡献,所以最后的期望就是(cnt*frac{C_{n-2}^{k-2}}{C_n^k}),其中(cnt)表示满足题目要求的点对数。
那么有多少符合条件的点对数呢,我么可以通过点分治求出。
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 5e4 + 5;
int n, m, is[N], tot, head[N], mx, cnt, maxx[N], siz[N], rt, vis[N], sum, c1[N], c2[N], mxdep[N], val[15];
struct node{int to, nex;}a[N << 1];
inline int read()
{
int x = 0, f = 1; char ch = getchar();
while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') {x = (x << 3) + (x << 1) + (ch ^ 48); ch = getchar();}
return x * f;
}
void add(int x, int y) {a[++ tot].to = y; a[tot].nex = head[x]; head[x] = tot;}
void dfs(int x, int fa, int dis, int now)
{
if(is[dis] && x >= now) cnt ++; if(dis >= mx) return;
for(int i = head[x]; i; i = a[i].nex)
{
int y = a[i].to;
if(y == fa) continue;
dfs(y, x, dis + 1, now);
}
}
void work()
{
for(int i = 1; i <= n; i ++) dfs(i, 0, 0, i);
double k; k = (n / 3) + (n % 3 != 0);
printf("%.2f
", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
k = (n / 3) + (n % 3 == 2); printf("%.2f
", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
k = (n / 3); printf("%.2f
", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
}
void get_root(int x, int fa)
{
siz[x] = 1; maxx[x] = 0;
for(int i = head[x]; i; i = a[i].nex)
{
int y = a[i].to;
if(y == fa || vis[y]) continue;
get_root(y, x); siz[x] += siz[y];
maxx[x] = max(maxx[x], siz[y]);
}
maxx[x] = max(maxx[x], sum - siz[x]);
if(maxx[x] < maxx[rt]) rt = x;
}
void get_ans(int x, int fa, int dis)
{
c2[dis] ++; mxdep[x] = dis;
for(int i = 1; i <= m; i ++) if(dis <= val[i]) cnt += c1[val[i] - dis];
for(int i = head[x]; i; i = a[i].nex)
{
int y = a[i].to;
if(y == fa || vis[y]) continue;
get_ans(y, x, dis + 1);
mxdep[x] = max(mxdep[x], mxdep[y]);
}
}
void dfs(int x, int fa)
{
vis[x] = 1; c1[0] = 1; int em = 0;
for(int i = head[x]; i; i = a[i].nex)
{
int y = a[i].to;
if(y == fa || vis[y]) continue;
get_ans(y, x, 1); em = max(em, mxdep[y]);
for(int j = 0; j <= mxdep[y]; j ++) c1[j] += c2[j], c2[j] = 0;
}
for(int j = 0; j <= em; j ++) c1[j] = c2[j] = 0;
for(int i = head[x]; i; i = a[i].nex)
{
int y = a[i].to;
if(y == fa || vis[y]) continue;
sum = siz[y]; rt = 0; get_root(y, 0); dfs(rt, 0);
}
}
void work2()
{
maxx[0] = 0x3f3f3f3f; sum = n; get_root(1, 0); dfs(rt, 0);
double k; k = (n / 3) + (n % 3 != 0);
printf("%.2f
", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
k = (n / 3) + (n % 3 == 2); printf("%.2f
", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
k = (n / 3); printf("%.2f
", (double)cnt * k * (k - 1) / (double)(1ll * n * (n - 1)));
}
int main()
{
n = read(); m = read();
for(int i = 1, x; i <= m; i ++) {x = read(); is[x] = 1; mx = max(mx, x); val[i] = x;}
for(int i = 1, x, y; i <= n - 1; i ++)
{
x = read(); y = read();
add(x, y); add(y, x);
}
if(mx <= 100 || n <= 1000) work(); else work2();
return 0;
}