题目
你有一棵 (n) 节点的树 ,回答 (m) 个询问,每次询问给你两个整数 (l,r) ,问存在多少个整数 (k) 使得从 (l) 沿着 (l o r) 的简单路径走 (k) 步恰好到达 (k) 。
分析
考虑离线后按链记贡献
从 (l) 到 (lca(l,r)) 这段链上,可以计入贡献的点 (x) 满足 (dep[l]-x=dep[x]),称为一类贡献
即 (dep[x]+x=dep[l]), 因为已知 (dep[l]),所以直接开桶计算
从 (lca(l,r)) 到 (r) 这段链上,可以计入贡献的点 (x) 满足 (dep[lca]+(x-dep[l]-dep[lca])=dep[x]),称为二类贡献
即 (dep[x]-x=2 imes dep[lca]-dep[l]),同样可以直接开另一个桶计算
因为 (dfs) 下来时桶记录的是根到当前点的信息,所以算贡献的时候要减去 (lca) 处的假贡献
(lca) 也可能成为需要贡献,所以算二类贡献的时候减去 (father_{lca}) 处的贡献
具体细节体现在代码
(Code)
#include<cstdio>
#include<vector>
using namespace std;
const int N = 3e5 + 5;
int n, m, dep[N], d[2][2*N], fa[N], da[N], vis[N], l[N], r[N], lca[N], ans[N];
vector<int> e[N];
struct node1{int x, id;};
vector<node1> q1[N];
struct node2{int cs, ty, f, id;};
vector<node2> q2[N];
int find(int x){return fa[x] == x ? x : fa[x] = find(fa[x]);}
void dfs(int x, int dad)
{
da[x] = dad, dep[x] = dep[dad] + 1;
for(register int i = 0; i < e[x].size(); i++)
{
if (e[x][i] == dad) continue;
dfs(e[x][i], x);
}
}
void dfs1(int x, int dad)
{
vis[x] = 1;
for(register int i = 0; i < e[x].size(); i++)
{
if (e[x][i] == dad) continue;
dfs1(e[x][i], x), fa[e[x][i]] = x;
}
for(register int i = 0; i < q1[x].size(); i++)
if (vis[q1[x][i].x]) lca[q1[x][i].id] = find(q1[x][i].x);
}
void dfs2(int x, int dad)
{
++d[0][dep[x] + x], ++d[1][dep[x] - x + n];
for(register int i = 0; i < q2[x].size(); i++)
ans[q2[x][i].id] += q2[x][i].f * d[q2[x][i].ty][q2[x][i].cs];
for(register int i = 0; i < e[x].size(); i++)
{
if (e[x][i] == dad) continue;
dfs2(e[x][i], x);
}
--d[0][dep[x] + x], --d[1][dep[x] - x + n];
}
int main()
{
freopen("query.in" , "r" , stdin);
freopen("query.out" , "w" , stdout);
scanf("%d%d" , &n , &m);
int x , y;
for(register int i = 1; i < n; i++)
{
scanf("%d%d" , &x , &y);
e[x].push_back(y), e[y].push_back(x);
}
for(register int i = 1; i <= m; i++)
{
scanf("%d%d" , &l[i], &r[i]);
q1[l[i]].push_back(node1{r[i], i});
q1[r[i]].push_back(node1{l[i], i});
}
for(register int i = 1; i <= n; i++) fa[i] = i;
dfs(1, 0), dfs1(1, 0);
for(register int i = 1; i <= m; i++)
{
q2[l[i]].push_back(node2{dep[l[i]], 0, 1, i});
q2[lca[i]].push_back(node2{dep[l[i]], 0, -1, i});
q2[r[i]].push_back(node2{2*dep[lca[i]]-dep[l[i]]+n, 1, 1, i});
if (lca[i] > 1)
q2[da[lca[i]]].push_back(node2{2*dep[lca[i]]-dep[l[i]]+n, 1, -1, i});
}
dfs2(1, 0);
for(register int i = 1; i <= m; i++) printf("%d
" , ans[i]);
}