题面
题解
数据范围已经提示我们要使用虚树，接下来考虑如何在虚树上做 $dp$。
以下摘自hzwer博客:
构建虚树以后，用两遍 $dp$（先自底向上、再自顶向下）求出虚树上每个点最近的议事处（距离相同时取编号较小者）。
然后枚举虚树上的每一条边，考虑它对两个端点所属议事处的答案贡献，
其中路径上的分界点可以用倍增二分求出。
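设虚树边 $(a, b)$ 中 $a$ 是 $b$ 的父亲，$a$ 到 $b$ 路径上 $a$ 的第一个儿子为 $x$，分界点为 $mid$（$mid$ 的子树去掉 $b$ 的子树后剩下的点都归 $b$ 的议事处管辖），
则对 $a$ 的议事处的贡献是 $size_x - size_{mid}$，
对 $b$ 的议事处的贡献是 $size_{mid} - size_b$；若 $a$、$b$ 最近的议事处相同，则 $size_x - size_b$ 全部计入该议事处。
最后还要算上没被考虑的点：对虚树上的每个点 $u$，$size_u$ 减去 $u$ 所有虚树儿子方向上的子树大小之后剩下的点，都归属于 $u$ 最近的议事处。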
Code
// luogu-judger-enable-o2
#include<bits/stdc++.h>
#define LL long long
#define RG register
using namespace std;
// Reads the next integer from stdin into x.
// Characters before the first '-' or digit are skipped; a '-' makes the value
// negative. If EOF is reached before any digit appears, x is left as 0.
// (No `register` here: that storage class was removed in C++17.)
template<class T> inline void read(T &x) {
    x = 0;
    // int, not char: getchar() reports EOF as an int, which a plain char
    // cannot hold reliably (it never matches on unsigned-char targets).
    int c = getchar();
    bool negative = false;
    // Stop at EOF as well; otherwise truncated input would loop forever.
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == '-') negative = true, c = getchar();
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    if (negative) x = -x;
}
// Writes the integer x to stdout in decimal (no separator, no newline).
// The digits are produced from an unsigned magnitude so that negating the
// minimum value of T (e.g. INT_MIN) cannot overflow.
// (No `register` here: that storage class was removed in C++17.)
template<class T> inline void write(T x) {
    if (x == 0) { putchar('0'); return; }
    using U = typename std::make_unsigned<T>::type;
    U mag = static_cast<U>(x);
    if (x < 0) {
        putchar('-');
        mag = static_cast<U>(static_cast<U>(0) - mag);  // modular negation, well defined
    }
    char digits[24];  // enough for any 64-bit magnitude
    int len = 0;
    while (mag > 0) {
        digits[len++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    }
    while (len > 0) putchar(digits[--len]);
}
int n;
const int N = 300010;
// Singly linked adjacency lists stored in one edge pool:
// last[u] is the index of the most recently added edge out of u (0 = none),
// and g[e].next is the index of the edge added before it. Index 0 is never
// used for a real edge, so it serves as the list terminator.
struct node {
    int to, next;
}g[N<<1];
int last[N], gl;
// Prepends the directed edge x -> y to x's adjacency list.
inline void add(int x, int y) {
    // Standard C++ aggregate initialisation; the former `(node){...}` form is
    // a C99 compound literal that C++ only accepts as a compiler extension.
    g[++gl] = node{y, last[x]};
    last[x] = gl;
}
int dfn[N], cnt, siz[N], dep[N], anc[N][21], rem[N], bel[N];
// Depth-first traversal of the original tree from u (whose parent is fa).
// Assigns preorder numbers dfn[], subtree sizes siz[], depths dep[] of the
// children, and the binary-lifting table anc[u][k] = 2^k-th ancestor of u.
// NOTE(review): recursive — a path-shaped tree of up to N nodes recurses that
// deep, so this relies on the judge providing a large stack.
void init(int u, int fa) {
    dfn[u] = ++cnt;
    siz[u] = 1;
    anc[u][0] = fa;
    for (int k = 0; k < 20; ++k)
        anc[u][k + 1] = anc[anc[u][k]][k];
    for (int e = last[u]; e != 0; e = g[e].next) {
        const int child = g[e].to;
        if (child != fa) {
            dep[child] = dep[u] + 1;
            init(child, u);
            siz[u] += siz[child];
        }
    }
}
int lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
for (int i = 20; i >= 0; i--)
if (dep[x]-(1<<i) >= dep[y])
x = anc[x][i];
if (x == y) return x;
for (int i = 20; i >= 0; i--)
if (anc[x][i] != anc[y][i])
x = anc[x][i], y = anc[y][i];
return anc[x][0];
}
// Number of edges on the original-tree path between x and y.
int dis(int x, int y) {
    const int w = lca(x, y);
    return (dep[x] - dep[w]) + (dep[y] - dep[w]);
}
// Per-query scratch state:
//   top  - stack height of s[] while building the virtual tree
//   len  - number of virtual-tree nodes recorded in c[1..len]
//   m    - number of chosen cities; a[] holds them (sorted), b[] in input order
//   f    - answer accumulated per chosen city
int top, len, m;
int a[N], b[N], s[N], c[N], f[N];
// Orders vertices by DFS preorder number, as the virtual-tree build requires.
bool cmp(int a, int b) {
    return dfn[b] > dfn[a];
}
// Adds vertex x to the virtual tree under construction. Vertices must arrive
// in increasing dfn order. s[1..top] holds the current ancestor chain;
// virtual-tree edges are emitted with add() as vertices leave that chain.
inline void insert(int x) {
    if (top != 1) {
        const int o = lca(x, s[top]);
        // While the second-from-top entry is o or below it, the top entry's
        // chain is complete: link it to its stack parent and pop it.
        for (; top > 1 && dfn[s[top - 1]] >= dfn[o]; --top)
            add(s[top - 1], s[top]);
        // The branching point o becomes the new top, adopting the old top.
        if (s[top] != o) {
            add(o, s[top]);
            s[top] = o;
        }
    }
    s[++top] = x;
}
void dfs1(int x) {
c[++len] = x; rem[x] = siz[x];
for (int i = last[x]; i; i = g[i].next) {
dfs1(g[i].to);
if (!bel[g[i].to]) continue;
int t1 = dis(bel[g[i].to], x), t2 = dis(bel[x], x);
if ((t1 == t2 && bel[g[i].to] < bel[x]) || t1 < t2 || !bel[x])
bel[x] = bel[g[i].to];
}
return ;
}
void dfs2(int x) {
for (int i = last[x]; i; i = g[i].next) {
int t1 = dis(bel[x], g[i].to), t2 = dis(bel[g[i].to], g[i].to);
if ((t1 == t2 && bel[g[i].to] > bel[x]) || t1 < t2 || !bel[g[i].to])
bel[g[i].to] = bel[x];
dfs2(g[i].to);
}
return ;
}
void solve(int a, int b) {
int x = b, mid = b;
for (int i = 20; i >= 0; i--)
if (dep[anc[x][i]] > dep[a])
x = anc[x][i];
rem[a] -= siz[x];
if (bel[a] == bel[b]) {
f[bel[a]] += siz[x]-siz[b];
return ;
}
for (int i = 20; i >= 0; i--) {
int nxt = anc[mid][i];
if (dep[nxt] <= dep[a]) continue;
int t1 = dis(bel[a], nxt), t2 = dis(bel[b], nxt);
if (t1 > t2 || (t1 == t2 && bel[b] < bel[a])) mid = nxt;
}
f[bel[a]] += siz[x]-siz[mid];
f[bel[b]] += siz[mid]-siz[b];
return ;
}
// Answers one query read from stdin: m chosen cities. Prints, in input order,
// how many original-tree vertices each chosen city is nearest to, then a
// newline. Everything this query wrote into per-vertex state is reset at the end.
void query() {
    top = len = gl = 0;
    read(m);
    // a[] will be sorted by dfn; b[] keeps the input order for the output.
    for (int i = 1; i <= m; i++) read(a[i]), b[i] = a[i];
    // Every chosen city is its own nearest chosen city.
    for (int i = 1; i <= m; i++) bel[a[i]] = a[i];
    sort(a+1, a+1+m, cmp);
    // Build the virtual tree. The root 1 always sits at the bottom of the
    // stack; if it was chosen too, skip it in the loop so it is not pushed
    // twice (previously this relied on s[0] == 0 and lca(1, 0) == 0).
    s[++top] = 1;
    for (int i = 1; i <= m; i++)
        if (a[i] != 1) insert(a[i]);
    // Whatever remains on the stack is one ancestor chain: link it up.
    for (int i = 1; i < top; i++) add(s[i], s[i+1]);
    // Nearest chosen city for every virtual-tree node: bottom-up, then top-down.
    dfs1(1); dfs2(1);
    // Split the vertices hanging off each virtual-tree edge.
    for (int i = 1; i <= len; i++)
        for (int j = last[c[i]]; j; j = g[j].next)
            solve(c[i], g[j].to);
    // Vertices not below any virtual-tree child belong to the node's own owner.
    for (int i = 1; i <= len; i++) f[bel[c[i]]] += rem[c[i]];
    for (int i = 1; i <= m; i++) write(f[b[i]]), putchar(' ');
    // Must be the escape sequence: a raw line break inside a character
    // literal is ill-formed and does not compile.
    putchar('\n');
    // Reset only what this query touched, ready for the next one.
    for (int i = 1; i <= len; i++) f[c[i]] = bel[c[i]] = last[c[i]] = 0;
}
// Reads the tree (n vertices, n-1 undirected edges), preprocesses it from
// root 1, then answers q queries.
int main() {
    read(n);
    for (int e = 1; e < n; ++e) {
        int u, v;
        read(u);
        read(v);
        add(u, v);
        add(v, u);
    }
    init(1, 0);
    // Drop the original tree's adjacency lists so last[]/g[] can be reused
    // for each query's virtual tree (query() resets gl itself).
    memset(last, 0, sizeof(last));
    int q;
    for (read(q); q != 0; --q) query();
    return 0;
}