思路
我们可以把到每个点的期望步数算出来取max?但是直接算显然是不行的
那就可以用Min-Max来容斥一下
设(g_{s})是从x到s中任意一个点的最小步数
设(f_{s})是从x到s中任意一个点的最大步数
然后就可以的得到
(f_{s}=sum_{tsubseteq s}(-1)^{|t|+1}g_t)
然后考虑g怎么求
设(p_i)是i点到任意一个子集中的点的最小步数
有(p_u=frac{1}{du_u}(1+p_{fa_u})+frac{1}{du_u}sum_{vin child_u}(p_v+1))
然后我们令(p_u=a_up_{fa_u}+b_u)
很显然有(p_u=frac{1}{du_u}sum(a_vf_u+b_v+1)+frac{1}{du_u}(p_{fa_u}))
然后移项可以得到(a_u=frac{1}{du_u-sum a_v},b_u=frac{sum(b_v+1)+1}{du_u-sum a_v})
然后因为x是根没有父亲,所以(g_{s}=(bitcnt(s) & 1)?b_u:-b_u)
然后就可以用子集前缀和进行累加了
最后直接输出答案就可以了
#include <bits/stdc++.h>
using namespace std;
const int Mod = 998244353;
const int N = 20;
int n, m, x;
int a[N], b[N], du[N];
int f[1 << N];
vector<int> g[N];
int main() {
#ifdef dream_maker
freopen("input.txt", "r", stdin);
#endif
function<int(int a, int b)> add = [&](int a, int b) {
return (a += b) >= Mod ? a - Mod : a;
};
function<int(int a, int b)> sub = [&](int a, int b) {
return (a -= b) < 0 ? a + Mod : a;
};
function<int(int a, int b)> mul = [&](int a, int b) {
return (long long) a * b % Mod;
};
function<int(int a, int b)> fast_pow = [&](int a, int b) {
int res = 1;
for (; b; b >>= 1, a = mul(a, a))
if (b & 1) res = mul(res, a);
return res;
};
function<int(int a)> bitcnt = [&](int a) {
int res = 0;
for (; a; a >>= 1)
if (a & 1) ++res;
return res;
};
function<void(int u, int fa, int s)> dfs = [&](int u, int fa, int s) {
if ((s >> (u - 1)) & 1) return;
a[u] = du[u], b[u] = (u == x) ? 0 : 1; // x不用向fa走的1
for (auto v : g[u]) {
if (v == fa) continue;
dfs(v, u, s);
a[u] = sub(a[u], a[v]);
b[u] = add(b[u], b[v] + 1);
}
a[u] = fast_pow(a[u], Mod - 2);
b[u] = mul(b[u], a[u]);
};
scanf("%d %d %d", &n, &m, &x);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d %d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
++du[u], ++du[v];
}
int up = (1 << n) - 1;
for (int s = 1; s <= up; s++) {
for (int i = 1; i <= n; i++)
a[i] = b[i] = 0;
dfs(x, 0, s);
f[s] = (bitcnt(s) & 1) ? b[x] : (Mod - b[x]) % Mod;
}
f[0] = 0;
for (int i = 1; i <= n; i++) { // 这个循环在外面
for (int s = 1; s <= up; s++) {
if ((s >> (i - 1)) & 1) {
f[s] = add(f[s], f[s ^ (1 << (i - 1))]);
}
}
}
while (m--) {
int num, cur, s = 0;
scanf("%d", &num);
while (num--) {
scanf("%d", &cur);
s |= 1 << (cur - 1);
}
printf("%d
", f[s]);
}
return 0;
}