这道题细节是真的多
看数据范围,这应该是一道虚树DP,我们先来想一下不用虚树怎么做
我们定义(id[i])为第i个点应该归哪一个议事处管理,且i到(id[i])的距离为(dis[i])
我们做两遍dfs,首先从下到上,用儿子更新父亲,再从上到下,用父亲更新儿子
更新过程十分简单,就类似于重链剖分的思路去更新就好了
然后这里要注意第一个细节,就是必须先用儿子更新父亲,再用父亲更新儿子,因为如果不这么做的话一个父亲可能有多个儿子,所以先更新儿子的话该儿子的'兄弟'不会更新到(画画图就理解了)
暴力DP就是这么做,那如果放在虚树上呢?
对于虚树上的点,我们仍然可以按照上述暴力DP方式来做
那么非虚树上的点呢?
首先虚树是保证了两个相邻的树点在原树中实在一条链上的
所以我们可以里用倍增的思想,求出两个相邻虚树点的分界点,分界点以上归上面的点管理,分界点一下同理
由于我们还要保证编号最小,所以这就是本题第二个坑点:我们需要判断两个虚树点id值的大小
具体实现中我们可以把分界处以下的点染成一种新的颜色,然后递归处理,我们可以保证先处理深度小的点再处理深度大的点
为什么要这么做呢?
因为分界点以下的点已经不归上方关键点管辖,所以我们可以用一种新的颜色覆盖点原来的颜色,表示被一个新点占领了
楼上(chenkehan)大佬给了十分形象的图片
于是就可以愉快的码码码了
献上十分丑陋的代码:
#include<bits/stdc++.h>
using namespace std;
#define il inline
#define re register
il int read() {
re int x = 0, f = 1; re char c = getchar();
while(c < '0' || c > '9') { if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
return x * f;
}
#define rep(i, s, t) for(re int i = s; i <= t; ++ i)
#define drep(i, s, t) for(re int i = t; i >= s; -- i)
#define Next(i, u) for(re int i = head[u]; i; i = e[i].next)
#define _ 300005
int n, m, Q, top[_], Top, st[_], son[_], size[_], dep[_], dfn[_], tot, cnt, head[_];
int id[_], ans[_], f[25][_], dis[_], vis[_], Size[_];
struct node {int a, id;}q[_];
struct edge {int v, next, w;}e[_ << 1];
il bool cmp(node a, node b) {return dfn[a.a] < dfn[b.a];}
il bool cmp1(node a, node b) {return a.id < b.id;}
il void add(int u, int v, int w) {
e[++ cnt] = (edge){v, head[u], w}, head[u] = cnt;
e[++ cnt] = (edge){u, head[v], w}, head[v] = cnt;
}
il void dfs1(int u, int fr) {
f[0][u] = fr, dep[u] = dep[fr] + 1, size[u] = 1, dfn[u] = ++ tot;
Next(i, u) if(e[i].v != fr) dfs1(e[i].v, u), size[u] += size[e[i].v];
}
il int LCA(int a, int b) {
if(dep[a] < dep[b]) swap(a, b);
drep(i, 0, 20) if(dep[a] - (1 << i) >= dep[b]) a = f[i][a];
drep(i, 0, 20) if(f[i][a] != f[i][b]) a = f[i][a], b = f[i][b];
return (a == b) ? a : f[0][a];
}
il int Dis(int a, int b) {return dep[a] + dep[b] - dep[LCA(a, b)] * 2;}
il void insert(int x) {
if(Top == 1 && x != 1) return (void)(st[++ Top] = x);
int lca = LCA(st[Top], x); if(lca == x) return;
while(Top > 1 && dep[st[Top - 1]] > dep[lca])
add(st[Top - 1], st[Top], Dis(st[Top - 1], st[Top])), -- Top;
if(dep[st[Top]] > dep[lca]) add(st[Top], lca, Dis(st[Top], lca)), -- Top;
if(dep[st[Top]] < dep[lca]) st[++ Top] = lca;
st[++ Top] = x;
}
il void dfs_mem(int u, int fr) {
Next(i, u) if(e[i].v != fr) dfs_mem(e[i].v, u);
head[u] = vis[u] = id[u] = dis[u] = Size[u] = ans[u] = 0;
}
il int get_fa(int u, int dis) {
int now = 0;
drep(i, 0, 20) if(now + (1 << i) <= dis) now += (1 << i), u = f[i][u];
return u;
}
il void dfs_get(int u, int fr) {
if(vis[u]) dis[u] = 0, id[u] = u, Size[u] = size[u];
else dis[u] = 123456789, Size[u] = size[u];
Next(i, u) {
int v = e[i].v, w = e[i].w; if(v == fr) continue;
dfs_get(v, u), w += dis[v];
if(dis[u] > w || (dis[u] == w && id[u] > id[v])) dis[u] = w, id[u] = id[v];
}
}
il void dfs1_get(int u, int fr) {
Next(i, u) {
int v = e[i].v, w = dis[u] + e[i].w; if(v == fr) continue;
if(w < dis[v] || (w == dis[v] && id[v] > id[u])) dis[v] = w, id[v] = id[u];
dfs1_get(v, u);
if(id[u] == id[v]) Size[u] -= size[v];
else {
int x = get_fa(v, (dis[v] + dis[u] + e[i].w + (id[u] > id[v]) - 1) / 2 - dis[v]);//Attention: (dep[x] - dep[y])即e[i].w的意义是经过的边的数量!我就是因为这里调了**的
Size[v] += size[x] - size[v], Size[u] -= size[x];
}
ans[id[v]] += Size[v];
}
if(u == 1) ans[id[1]] += Size[1];
}
int main() {
n = read();
rep(i, 1, n - 1) add(read(), read(), 0);
Q = read(), dfs1(1, 1), memset(head, 0, sizeof(head));
rep(i, 1, 20) rep(j, 1, n) f[i][j] = f[i - 1][f[i - 1][j]];
while(Q --) {
m = read(), st[Top = 1] = 1, cnt = 0;
rep(i, 1, m) q[i].a = read(), vis[q[i].a] = 1, q[i].id = i;
sort(q + 1, q + m + 1, cmp);
rep(i, 1, m) insert(q[i].a);
while(Top > 1) add(st[Top - 1], st[Top], Dis(st[Top - 1], st[Top])), -- Top;
dfs_get(1, 0), dfs1_get(1, 0), sort(q + 1, q + m + 1, cmp1);
rep(i, 1, m) printf("%d ", ans[q[i].a]);
puts(""), dfs_mem(1, 0);
}
return 0;
}