一、题目:
二、思路:
这道题真的把我迷了一个下午,感谢 zYzYzYzYz 大佬的点拨,让我终于明白了这道题。
注意到如果 (y) 的一个祖先是 (x),那么先 (mathbb{access}(x)) 再 (mathbb{access}(y)) 本质上相当于只 access 了一次。所以我们可以执行一个从下到上的树形 DP。
注意,以下这个 DP 状态比较复杂。
假设我们现在的目标是求出来 (x) 相关的信息,现在轮到了用 (x) 的儿子 (y) 来去更新 (x) 的信息。此时,(tmpf[k]) 表示"保证 (y) 之前的儿子边都不是实边",操作了 (k) 次的方案数;(tmpg[k]) 表示“保证 (y) 之前的儿子边中恰好有一个是实边”,操作了 (k) 次的方案数;(f[x,k]) 表示保证 (y) 及 (y) 之前的儿子边都不是实边,操作了 (k) 次的方案数;(g[x,k]) 表示保证处理完 (y) 之后,(x) 顶上那条边是实边,操作了 (k) 次的方案数。
则有状态转移方程:
在整个更新完 (x) 的答案之后,我们将 (g[x]) 数组中的值全部赋给 (f[x]),此时 (f[x,k]) 的意义变成 (x) 顶上那条边不是实边,操作了 (k) 次的方案数。为什么可以直接将 (g[x]) 数组中的值赋给 (f[x]) 呢?因为 (x) 顶上的边不是实边,只有可能是 (x) 的父亲用了一次 access,把原本是实边的边变成了虚边。
最后有一个小细节。就是最后将 (g) 赋给 (f) 的时候,(g[x,0]) 是等于 0 的,而 (g[x,1]) 是大于 0 的。赋给 (f) 了之后,为了保证意义上的自洽以及树的形态不能重复,要将 (f[x,0]gets 1),(f[x,1]gets f[x,1]-1)。
由于 (kleq siz_x),所以这是一个经典的 (O(nK)) 的 DP。
三、代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
#define FILEIN(s) freopen(s, "r", stdin)
#define FILEOUT(s) freopen(s, "w", stdout)
#define mem(s, v) memset(s, v, sizeof s)
inline int read(void) {
int x = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
return f * x;
}
const int MAXN = 10005, MOD = 998244353, MAXK = 505;
int n, K, siz[MAXN];
int head[MAXN], tot;
long long f[MAXN][MAXK], g[MAXN][MAXK];
long long tmpf[MAXK], tmpg[MAXK];
struct Edge {
int y, next;
Edge() {}
Edge(int _y, int _next) : y(_y), next(_next) {}
}e[MAXN << 1];
inline void connect(int x, int y) {
e[++ tot] = Edge(y, head[x]);
head[x] = tot;
}
void dfs(int x, int fa) {
siz[x] = 1;
g[x][1] = 1;
f[x][0] = 1;
for (int i = head[x]; i; i = e[i].next) {
int y = e[i].y;
if (y == fa) continue;
dfs(y, x);
for (int k = 0; k <= min(K, siz[x]); ++ k) {
tmpf[k] = f[x][k];
tmpg[k] = g[x][k];
f[x][k] = g[x][k] = 0;
}
for (int j = 0; j <= min(K, siz[x]); ++ j) {
for (int k = 0; k <= siz[y] && j + k <= K; ++ k) {
(f[x][j + k] += tmpf[j] * f[y][k] % MOD) %= MOD;
(g[x][j + k] += tmpf[j] * g[y][k] % MOD + tmpg[j] * f[y][k] % MOD) %= MOD;
}
}
siz[x] += siz[y];
}
memcpy(f[x], g[x], sizeof f[x]);
f[x][0] = 1; -- f[x][1];
}
int main() {
FILEIN("access.in"); FILEOUT("access.out");
n = read(); K = read();
for (int i = 1; i < n; ++ i) {
int x = read(), y = read();
connect(x, y); connect(y, x);
}
dfs(1, 0);
long long res = 0;
for (int i = 0; i <= min(n, K); ++ i) (res += f[1][i]) %= MOD;
printf("%lld
", res);
return 0;
}