题目链接
题意:
外星人的母舰可以看成是一棵 n 个节点、 n−1 条边的无向树,树上的节点用 1,2,⋯,n 编号。JYY 的特工已经装备了隐形模块,可以在外星人母舰中不受限制地活动,可以神不知鬼不觉地在节点上安装监听设备。
如果在节点 u 上安装监听设备,则 JYY 能够监听与 u 直接相邻所有的节点的通信。换言之,如果在节点 u 安装监听设备,则对于树中每一条边 (u,v) ,节点 v 都会被监听。
特别注意放置在节点 u 的监听设备并不监听 u 本身的通信,这是 JYY 特别为了防止外星人察觉部署的战术。
JYY 的特工一共携带了 k 个监听设备,现在 JYY 想知道,有多少种不同的放置监听设备的方法,能够使得母舰上所有节点的通信都被监听?为了避免浪费,每个节点至多只能安装一个监听设备,且监听设备必须被用完。
(nleq 100000 ,kleq 100)。
显然是树形背包DP。
但是,状态比较难设计。如果u没有被监视,则u的子节点必须至少有一个选。所以要加一维表示选不选。
而如果u被监视了,则u的子节点可以都不选。所以要加一维表示u是否被监视。
这样就好理解了。
f1[a+b][0]=(f1[a+b][0]+1ll*x0[a][0]*dp[v[i]][b][0][0])%md;
f1[a+b][1]=(f1[a+b][1]+1ll*x0[a][0]*dp[v[i]][b][0][1]+1ll*x0[a][1]*(dp[v[i]][b][0][0]+dp[v[i]][b][0][1]))%md;
f2[a+b][0]=(f2[a+b][0]+1ll*x1[a][0]*dp[v[i]][b][1][0])%md;
f2[a+b][1]=(f2[a+b][1]+1ll*x1[a][0]*dp[v[i]][b][1][1]+1ll*x1[a][1]*(dp[v[i]][b][1][1]+dp[v[i]][b][1][0]))%md;
关键是复杂度。
首先,常规树形背包是(O(n^2))的。
就是每对点会在lca处贡献复杂度。
但是,这个算法,最初觉得是(O(nk^2))的,实际上是(O(nk))的。
证明:
- 根据正常树形背包的复杂度(O(n^2)),小于等于k的最多产生(n/k*k^2)的复杂度。
- 大于k与大于k的合并一次,被合并的就增加k,最多n/k次,最多产生(n/k*k^2)的复杂度。
- 大于k的与小于等于k的合并时,每个小于等于k的最多被合并一次,所以是(n*s_1+n*s_2+...+n*s_m),也是(nk)。
还有一种理解,不知道对不对:
把树按照dfs序变为序列。
然后,在子树中枚举取x个,可以理解为取dfs序的前(后)x个。
而合并时,认为一棵子树取后x个,另一棵取前y个。((x+yleq k))。这可以合并为长x+y的区间。
这其实就是长度不大于k的子串,最多有nk个。
但是,因为有取0个的情况,所以实际做题时,大约有2的常数。但那个常数就忽略了可以。
代码
#include <stdio.h>
#define min(a, b)(a < b ? a: b)
#define md 1000000007
inline int read() {
char ch;
while ((ch = getchar()) < '0' || ch > '9');
int rt = (ch ^ 48);
while ((ch = getchar()) >= '0' && ch <= '9') rt = (rt << 3) + (rt << 1) + (ch ^ 48);
return rt;
}
int dp[100002][102][2][2],f1[102][2],f2[102][2];
int x0[102][2],x1[102][2],sz[100002];
int fr[100002],ne[200002],v[200002],bs = 0,k;
void addb(int a, int b) {
v[bs] = b;
ne[bs] = fr[a];
fr[a] = bs++;
}
void dfs(int u, int fu) {
int si = 0;
for (int i = fr[u]; i != -1; i = ne[i]) {
if (v[i] != fu) {
dfs(v[i], u);
si += sz[v[i]];
}
}
for (int i = 0; i <= min(k, si); i++) x0[i][0] = x1[i][0] = 0;
x0[0][0] = x1[0][0] = 1;
si = 0;
for (int i = fr[u]; i != -1; i = ne[i]) {
if (v[i] == fu) continue;
int rt = sz[v[i]];
for (int a = 0; a <= min(si, k); a++) {
for (int b = 0; b <= min(rt, k - a); b++) {
f1[a + b][0] = (f1[a + b][0] + 1ll * x0[a][0] * dp[v[i]][b][0][0]) % md;
f1[a + b][1] = (f1[a + b][1] + 1ll * x0[a][0] * dp[v[i]][b][0][1] + 1ll * x0[a][1] * (dp[v[i]][b][0][0] + dp[v[i]][b][0][1])) % md;
f2[a + b][0] = (f2[a + b][0] + 1ll * x1[a][0] * dp[v[i]][b][1][0]) % md;
f2[a + b][1] = (f2[a + b][1] + 1ll * x1[a][0] * dp[v[i]][b][1][1] + 1ll * x1[a][1] * (dp[v[i]][b][1][1] + dp[v[i]][b][1][0])) % md;
}
}
si += rt;
for (int a = 0; a <= min(si, k); a++) {
x0[a][0] = f1[a][0];
x0[a][1] = f1[a][1];
x1[a][0] = f2[a][0];
x1[a][1] = f2[a][1];
f1[a][0] = f1[a][1] = f2[a][0] = f2[a][1] = 0;
}
}
for (int a = 0; a <= min(si, k); a++) {
dp[u][a][0][0] = x0[a][1];
dp[u][a][1][0] = (x0[a][0] + x0[a][1]) % md;
}
for (int a = 1; a <= min(si + 1, k); a++) {
dp[u][a][0][1] = x1[a - 1][1];
dp[u][a][1][1] = (x1[a - 1][0] + x1[a - 1][1]) % md;
}
sz[u] = si + 1;
}
int main() {
int n;
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; i++) fr[i] = -1;
for (int i = 0; i < n - 1; i++) {
int a,b;
a = read();
b = read();
addb(a, b);
addb(b, a);
}
dfs(1, 0);
printf("%d", (dp[1][k][0][0] + dp[1][k][0][1]) % md);
return 0;
}