前言
是什么危险的想法让我来做这道危险的题。
题目
题目链接:https://www.luogu.com.cn/problem/P3899
设 \(\text T\) 为一棵有根树,我们做如下的定义:
- 设 \(a\) 和 \(b\) 为 \(\text T\) 中的两个不同节点。如果 \(a\) 是 \(b\) 的祖先,那么称“\(a\) 比 \(b\) 不知道高明到哪里去了”。
- 设 \(a\) 和 \(b\) 为 \(\text T\) 中的两个不同节点。如果 \(a\) 与 \(b\) 在树上的距离不超过某个给定常数 \(x\),那么称“ \(a\) 与 \(b\) 谈笑风生”。
给定一棵 \(n\) 个节点的有根树 \(\text T\),节点的编号为 \(1\) 到 \(n\),根节点为 \(1\) 号节点。
你需要回答 \(q\) 个询问,询问给定两个整数 \(p\) 和 \(k\),问有多少个有序三元组 \((a,b,c)\) 满足:
- \(a,b,c\) 为 \(\text T\) 中三个不同的点,且 \(a\) 为 \(p\) 号节点;
- \(a\) 和 \(b\) 都比 \(c\) 不知道高明到哪里去了;
- \(a\) 和 \(b\) 谈笑风生。这里谈笑风生中的常数为给定的 \(k\)。
思路
中文翻译:询问树上有多少组点对 \((a,b,c)\) 满足 \(a,b\) 均为 \(c\) 的祖先且 \(a,b\) 两点之间距离不超过 \(k\)。要求点对中不含相同的点。
那么分成两类讨论:
- \(b\) 为 \(a\) 的祖先。
这一部分很好求,显然无论 \(b\) 取那个点,\(c\) 必然是 \(a\) 子树中的任意一个节点(除 \(a\) 外)。那么由于 \(a\) 的祖先有 \(dep[a]-1\) 个,所以这部分答案就是 \(\max (dep[a]-1,k)\times size[a]\)。 - \(a\) 为 \(b\) 的祖先。
设 \(b\) 是任意一个位于 \(a\) 的子树内且与 \(a\) 的距离不超过 \(k\) 的节点,那么符合要求的点对 \((a,b,c)\) 的 \(c\) 的取值就是 \(size[b]\) 个。
所以答案即为 \(\sum_{i(i\in \operatorname{T}_a\ |\ dis(i,a)\leq k)}size[i]\)。
对于每一个点建立一棵线段树,线段树的区间 \([l,r]\) 即为该点子树中深度在 \([l,r]\) 的点的 \(size\) 之和。这样询问可以 \(O(\log n)\) 处理。
而维护线段树则采用线段树合并即可。
时间复杂度 \(O(n\log n)\)。
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N=300010,M=10000010;
int n,Q,tot,head[N],rt[N],dep[N],size[N];
struct edge
{
int next,to;
}e[N*2];
void add(int from,int to)
{
e[++tot].to=to;
e[tot].next=head[from];
head[from]=tot;
}
struct seg
{
int tot,lc[M],rc[M];
ll ans[M];
int update(int x,int l,int r,int k,ll val)
{
if (!x) x=++tot;
ans[x]+=val;
if (l==k && r==k) return x;
int mid=(l+r)>>1;
if (k<=mid) lc[x]=update(lc[x],l,mid,k,val);
else rc[x]=update(rc[x],mid+1,r,k,val);
return x;
}
int merge(int x,int y,int l,int r)
{
if (!x || !y) return x+y;
int p=++tot,mid=(l+r)>>1;
lc[p]=merge(lc[x],lc[y],l,mid);
rc[p]=merge(rc[x],rc[y],mid+1,r);
ans[p]=ans[x]+ans[y];
return p;
}
ll query(int x,int l,int r,int ql,int qr)
{
if (l==ql && r==qr) return ans[x];
int mid=(l+r)>>1;
if (qr<=mid) return query(lc[x],l,mid,ql,qr);
else if (ql>mid) return query(rc[x],mid+1,r,ql,qr);
else return query(lc[x],l,mid,ql,mid)+query(rc[x],mid+1,r,mid+1,qr);
}
}seg;
void dfs(int x,int fa)
{
dep[x]=dep[fa]+1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa)
{
dfs(v,x);
size[x]+=size[v]+1;
// rt[x]=seg.update(rt[x],1,n,dep[v],size[v]);
// rt[x]=seg.merge(rt[x],rt[v],1,n);
}
}
rt[x]=seg.update(rt[x],1,n,dep[x],size[x]);
if (x>1) rt[fa]=seg.merge(rt[fa],rt[x],1,n);
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&Q);
for (int i=1,x,y;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dfs(1,0);
while (Q--)
{
int p,k;
scanf("%d%d",&p,&k);
ll ans1=1LL*size[p]*min(k,dep[p]-1);
ll ans2=seg.query(rt[p],1,n,min(dep[p]+1,n),min(dep[p]+k,n));
printf("%lld\n",ans1+ans2);
}
return 0;
}