分析:显然的,树形dp,状态也很好想到:f[i][j]表示以i为根的子树收集到j个果子的方案数.转移的话就相当于是背包问题,每个子节点可以选或不选.如果不选子节点k的话,那么以k为根的子树的边无论断不断都没关系,贡献就是f[i][j] * 2^(size[k]).如果选的话,枚举一下收集到多少个果子,对答案的贡献就是f[i][j - p] * f[k][p].基本的计数原理.
不过这个转移是O(n^3)的,怎么优化呢?状态定义为这个样子是没法继续优化的,如果把状态的表示改成dfs到第i个点,收集到j个果子的方案数,就能够神奇地做到O(n^2)了.因为dfs是每次先向下递归,然后子节点向上回溯嘛,向下递归的时候就用父节点的状态去更新子节点的状态,向上回溯就用子节点的答案去更新父节点的答案.也就是说:向下走,更新状态;向上走,统计答案.
60分暴力:
#include <cstdio> #include <cmath> #include <cstring> #include <iostream> #include <algorithm> using namespace std; const long long mod = 1e9 + 7; typedef long long ll; ll n, g[1010], k, q[1010], sizee[1010], a[1010], f[1010][1010], head[1010], to[2020], nextt[2020], tot = 1; void add(ll x, ll y) { to[tot] = y; nextt[tot] = head[x]; head[x] = tot++; } void dfs(ll u, ll fa) { f[u][a[u]] = 1; sizee[u] = 1; for (ll i = head[u]; i; i = nextt[i]) { ll v = to[i]; if (v == fa) continue; dfs(v, u); sizee[u] += sizee[v]; for (ll j = 0; j <= k; j++) { g[j] = q[sizee[v] - 1] * f[u][j] % mod; for (ll kk = 0; kk <= j; kk++) { g[j] += f[v][kk] * f[u][j - kk] % mod; g[j] %= mod; } } for (ll j = 0; j <= k; j++) f[u][j] = g[j]; } } int main() { scanf("%lld%lld", &n, &k); for (ll i = 1; i <= n; i++) scanf("%lld", &a[i]); for (ll i = 1; i < n; i++) { ll x, y; scanf("%lld%lld", &x, &y); add(x, y); add(y, x); } q[0] = 1; q[1] = 2; for (ll i = 2; i <= n; i++) q[i] = q[i - 1] * 2 % mod; dfs(1, 0); printf("%lld ", f[1][k]); return 0; }
AC:
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> const int mod = 1e9 + 7; using namespace std; typedef long long ll; ll n, k, a[1010],sizee[1010], q[1010],f[1010][1010], head[1010], to[2020], nextt[2020], tot = 1; void add(int x, int y) { to[tot] = y; nextt[tot] = head[x]; head[x] = tot++; } void dfs(int u, int fa) { sizee[u] = 1; for (int i = head[u]; i; i = nextt[i]) { int v = to[i]; if (v == fa) continue; for (int j = 0; j + a[v] <= n; j++) f[v][j + a[v]] = f[u][j]; dfs(v, u); sizee[u] += sizee[v]; for (int j = 0; j <= n; j++) f[u][j] = (q[sizee[v] - 1] * f[u][j] % mod + f[v][j]) % mod; } } int main() { scanf("%lld%lld", &n, &k); for (int i = 1; i <= n; i++) scanf("%lld", &a[i]); for (int i = 1; i < n; i++) { ll x, y; scanf("%lld%lld", &x, &y); add(x, y); add(y, x); } f[1][a[1]] = 1; q[0] = 1; for (int i = 1; i <= n; i++) q[i] = q[i - 1] * 2 % mod; dfs(1, 0); printf("%lld ", f[1][k]); return 0; }