题意
给你一棵由 N 个节点构成的树 T。节点按照 1 到 N 编号,每个节点要么是白色,要么是黑色。有 Q 组询问,每组询问形如 (s, b)。你需要检查是否存在一个连通子图,其大小恰好是 s,并且包含恰好 b 个黑色节点。
数据
输入第一行,包含一个整数 T,表示测试数据组数。对于每组测试数据:
第一行包含两个整数 N 和 Q,分别表示树的节点个数和询问个数。
接下来 N - 1 行,每行包含两个整数 ui 和 vi,表示在树中 ui 和 vi 之间存在一条边。
接下来一行包含 N 个整数,c1, c2, ... , cN。如果 ci 为 0 表示第 i 个节点是白色的,如果 ci 为
1 表示第 i 个节点是黑色的。
接下来 Q 行,每行包含两个整数 si 和 bi,表示一组形如 (si, bi) 的询问。
对于每组询问输出一行字符串表示答案,其中 Yes 表示存在一个符合要求的连
通子图,No 表示不存在。
1 <= T <= 5, 2 <= N <= 5e3, 1 <= Q <= 1e5, 1 <= ui, vi <= N。
0 <= ci <= 1, 0 <= bi <= N, 1 <= si <= N, bi <= si。
输入
1
9 4
4 1
1 5
1 2
3 2
3 6
6 7
6 8
9 6
0 1 0 1 0 0 1 0 1
3 2
7 3
4 0
9 5
输出
Yes
Yes
No
No
说明
对于第一组询问,包含由 {6, 7, 9} 构成的连通子图,其中恰包含两个黑色节点 7 和 9。
对于第二组询问,包含由 {1, 2, 3, 4, 6, 7, 8} 构成的连通子图,其中恰包含三个黑色节点 2,4和 7。
对于第三组询问和第四组询问,均不存在符合要求的连通子图。
题解:
观察到一个现象,对于一个子树,在子树中子图的点数确定的情况下,可行的黑点数是一个连续的区间。
那么很自然地用f[i][j]记下在子树i中用了j个点的情况下最多和最少的黑点数,背包一下就可以了。
剩下一个问题,就是复杂度。你会发现每个点对在dp的过程中都只出现一次,所以复杂度是n^2的。
#include<bits/stdc++.h> using namespace std; #pragma comment(linker, "/STACK:102400000,102400000") #define ls i<<1 #define rs ls | 1 #define mid ((ll+rr)>>1) #define pii pair<int,int> #define MP make_pair typedef long long LL; const long long INF = 1e18+1LL; const double Pi = acos(-1.0); const int N = 5e3+10, M = 1e3+20, mod = 1e9+7,inf = 2e9; int f[N][N][2],mi[N],mx[N],c[N],siz[N]; int head[N*10],t = 1; struct ss{ int next,to; }e[N * 20]; int T,Q,n; inline void add(int u,int v) { e[t].to = v; e[t].next = head[u]; head[u] = t++; } void dfs(int u,int fa) { siz[u] = 1; if(c[u])f[u][1][0] = f[u][1][1] = 1; else f[u][1][0] = f[u][1][1] = 0; for(int i = head[u]; i; i = e[i].next){ int to = e[i].to; if(to == fa) continue; dfs(to,u); for(int k = siz[u]; k >= 1; --k){ for(int j = 1; j <= siz[to]; ++j) { f[u][k+j][0] = min(f[u][k+j][0],f[u][k][0] + f[to][j][0]); f[u][k+j][1] = max(f[u][k+j][1],f[u][k][1] + f[to][j][1]); } } siz[u] += siz[to]; } for(int j = 1; j <= siz[u]; ++j) mi[j] = min(mi[j],f[u][j][0]),mx[j] = max(mx[j],f[u][j][1]); } int main() { scanf("%d",&T); while(T--) { scanf("%d%d",&n,&Q); memset(head,0,sizeof(head)); t = 1; for(int i = 0; i <= n; ++i) mi[i] = inf,mx[i] = 0; for(int i = 0; i <= n; ++i) for(int j = 0; j <= n; ++j) f[i][j][0] = inf,f[i][j][1] = 0; for(int i = 1; i < n; ++i) { int x,y; scanf("%d%d",&x,&y); add(x,y); add(y,x); } for(int i = 1; i <= n; ++i) scanf("%d",&c[i]); dfs(1,0); while(Q--) { int x,y; scanf("%d%d",&x,&y); if(mi[x] <= y && mx[x] >= y) { printf("Yes "); } else printf("No "); } } return 0; }