题目大意
给出一棵树,其中(1)为根。之后每个点向父亲的父亲再连一条边,求得到的图中,每个点走到(1)的期望步数(等概率向相邻点走去)。
保证(i)的父亲(fa_i<i)。
(nleq 2000)
Solution
首先列方程,设(f_i)表示(i)走向(1)的期望步数,有:
[f_x=1+frac{1}{d_x}sum f_y
]
其中,(d_x)是(x)的度数,(y)是与(x)相邻的所有点。
直接高斯消元解方程,复杂度(O(n^3)),过不了。
观察这个系数矩阵的特点,第(x)行的系数只会在父亲,父亲的父亲,儿子,儿子的儿子处有值。如果我们从儿子往根消元,每次用第(x)行消去(fa_x)行和(fa_{fa_x})行的第(x)列,那么最后会得到一个下三角矩阵。这时第(x)行的系数只会在父亲,父亲的父亲处有值,我们从根往儿子消元,就能得到对角线矩阵了。
复杂度(O(n^2))。这个做法利用了系数矩阵的特点,减少消元次数,真是妙不可言~~~
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 2007, P = 998244353;
int n, fa[N], d[N], c[N][N], b[N];
int tot, st[N], nx[N << 2], to[N << 2];
void add(int u, int v) {
to[++tot] = v, nx[tot] = st[u], st[u] = tot;
to[++tot] = u, nx[tot] = st[v], st[v] = tot;
++d[u], ++d[v];
}
int pow(int a, int b) {
int ret = 1;
for (; b; a = 1ll * a * a % P, b >>= 1) if (b & 1) ret = 1ll * ret * a % P;
return ret;
}
int main() {
//freopen("in", "r", stdin);
freopen("b.in", "r", stdin);
freopen("b.out", "w", stdout);
scanf("%d", &n);
for (int i = 2; i <= n; ++i) scanf("%d", &fa[i]), add(i, fa[i]);
for (int i = 1; i <= n; ++i) if (fa[fa[i]]) add(i, fa[fa[i]]);
c[1][1] = 1;
for (int i = 2; i <= n; ++i) {
c[i][i] = 1;
for (int j = st[i]; j; j = nx[j]) c[i][to[j]] = P - pow(d[i], P - 2);
b[i] = 1;
}
for (int i = n; i >= 1; --i) {
if (fa[i]) {
int j = fa[i], res = 1ll * c[j][i] * pow(c[i][i], P - 2) % P;
for (int k = 1; k <= n; ++k) c[j][k] = (c[j][k] - 1ll * c[i][k] * res % P + P) % P;
b[j] = (b[j] - 1ll * b[i] * res % P + P) % P;
}
if (fa[fa[i]]) {
int j = fa[fa[i]], res = 1ll * c[j][i] * pow(c[i][i], P - 2) % P;
for (int k = 1; k <= n; ++k) c[j][k] = (c[j][k] - 1ll * c[i][k] * res % P + P) % P;
b[j] = (b[j] - 1ll * b[i] * res % P + P) % P;
}
}
for (int i = 1; i <= n; ++i) {
for (int l = st[i]; l; l = nx[l]) if (to[l] != fa[i] && to[l] != fa[fa[i]]) {
int j = to[l], res = 1ll * c[j][i] * pow(c[i][i], P - 2) % P;
for (int k = 1; k <= n; ++k) c[j][k] = (c[j][k] - 1ll * c[i][k] * res % P + P) % P;
b[j] = (b[j] - 1ll * b[i] * res % P + P) % P;
}
}
for (int i = 1; i <= n; ++i) printf("%d
", 1ll * b[i] * pow(c[i][i], P - 2) % P);
return 0;
}