分析:比较难的一道题,看到要求方案数,又是在一棵树上,自然就想起了树形dp.状态该怎么表示呢?首先肯定有一维状态表示以i为根的子树,考虑到i有没有匹配对答案也是有影响的,自然而然状态就出来了:f[i][0/1]表示以i为根的子树中,i取或不取的最大匹配.因为要求方案数,再开一个数组g[i][0/1]记录方案数.
接下来考虑怎么转移.如果i不选,那么它的子节点不管选不选都没关系:
f[i][0] = Σmax{f[j][0],f[j][1]},如果i要选,那么它的子节点中一定有一个没选的点k,其它点随意.
f[i][1] = max{f[k][0] + 1 + Σmax{f[j][0],f[j][1] (j != k)}}.对于g的转移就是比较常见的套路了,子树内加法原理,子树合并乘法原理.每次选肯定要选最大匹配的那种方案,所以开一个数组ans,如果f[i][0] > f[i][1],ans[i] = g[i][0]; f[i][0] < f[i][1],ans[i] = g[i][1];
f[i][0] = f[i][1],ans[i] = g[i][1] + g[i][0]. 那么g[i][0] = πans[j],i要选的话,k就不能直接用ans的值,那么g[i][1] = Σ(g[k][0] * (πans[j] / ans[k])),涉及到取模,用到了乘法逆元.
设计状态的时候想清楚当前点有哪几种状态,它们对答案有没有影响.转移的时候想想转移过来的子节点必须满足什么要求,其它的点该怎么分配.求方案数的时候要分清楚是乘法原理还是加法原理.
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const int maxn = 100010, mod = 1000000007; typedef long long ll; int T, P, n, head[maxn], to[maxn * 2], nextt[maxn * 2], tot = 1; ll f[maxn][2], g[maxn][2], ans[maxn]; void add(int x, int y) { to[tot] = y; nextt[tot] = head[x]; head[x] = tot++; } ll qpow(ll a, ll b) { ll res = 1; while (b) { if (b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1; } return res; } void dfs(int x, int fa) { f[x][0] = f[x][1] = ans[x] = 0; g[x][0] = g[x][1] = 1; bool flag = false; ll sum = 0, mul = 1; for (int i = head[x]; i; i = nextt[i]) { int v = to[i]; if (v != fa) { dfs(v, x); sum += max(f[v][0], f[v][1]); sum %= mod; f[x][0] += max(f[v][0], f[v][1]); f[x][0] %= mod; g[x][0] = g[x][0] * ans[v] % mod; mul = mul * ans[v] % mod; flag = true; } } if (!flag) g[x][1] = 0; for (int i = head[x]; i; i = nextt[i]) { int v = to[i]; if (v != fa) { if (f[x][1] < f[v][0] + 1 + sum - max(f[v][1], f[v][0])) { f[x][1] = (((f[v][0] + 1 + sum) % mod) - max(f[v][1], f[v][0]) + mod) % mod; g[x][1] = g[v][0] * mul % mod * qpow(ans[v],mod - 2) % mod; } else if (f[x][1] == f[v][0] + 1 + sum - max(f[v][1], f[v][0])) g[x][1] = (g[x][1] + g[v][0] * mul % mod * qpow(ans[v],mod - 2) % mod) % mod; } } if (f[x][0] > f[x][1]) ans[x] = g[x][0]; if (f[x][0] < f[x][1]) ans[x] = g[x][1]; if (f[x][0] == f[x][1]) ans[x] = g[x][0] + g[x][1]; ans[x] %= mod; } int main() { scanf("%d%d", &T, &P); while (T--) { memset(head, 0, sizeof(head)); tot = 1; scanf("%d", &n); for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); add(u, v); add(v, u); } dfs(1, 0); printf("%lld ", max(f[1][0], f[1][1])); if (P == 2) printf("%lld", ans[1] % mod); printf(" "); } return 0; }