题目
题目链接:https://www.luogu.com.cn/problem/P3806
给定一棵有 \(n\) 个点的树,询问树上距离为 \(k\) 的点对是否存在。
思路
什么彩笔快高中才学点分治 /kk。
我们可以考虑一种暴力:枚举每一个点 \(x\),然后将 \(x\) 到每一个点 \(y\)(并且 \(x\) 到 \(y\) 的路径中不能有之前枚举到的点)的距离 \(dis[y]\) 求出来,然后对于每一个询问 \(ask[i]\),看看 \(x\) 的其他子树内是否有长度为 \(ask[i]-dis[y]\) 的边。如果有,那么该询问答案为 AYE
。
显然这种方法的时间复杂度为 \(O(n^2)\)。
发现每次取 \(x\) 的复杂度在于其子树大小。如果数退化成一条链且 \(x\) 每次取的是链最左右两端的点的话,那么时间复杂度就退化至 \(O(n^2)\)。
所以考虑每次取 \(x\) 时取剩余部分的重心,这样就使得每一棵子树大小都不超过那一部分的一半。那么这样的时间复杂度就降至 \(O(n\log n)\)。
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=10010,M=110,K=10000010,Inf=1e9;
int n,m,tot,sum,rt,dis[N],maxp[N],head[N],size[N];
bool vis[N],used[K];
struct edge
{
int next,to,dis;
}e[N*2];
struct Query
{
int k,ans;
}ask[M];
void add(int from,int to,int dis)
{
e[++tot].to=to;
e[tot].dis=dis;
e[tot].next=head[from];
head[from]=tot;
}
void findrt(int x,int fa)
{
size[x]=1; maxp[x]=0;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa && !vis[v])
{
findrt(v,x);
size[x]+=size[v]; maxp[x]=max(maxp[x],size[v]);
}
}
maxp[x]=max(maxp[x],sum-size[x]);
if (maxp[x]<maxp[rt]) rt=x;
}
int dfs(int x,int fa,int d)
{
dis[++tot]=d; size[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (v!=fa && !vis[v])
size[x]+=dfs(v,x,d+e[i].dis);
}
return size[x];
}
void calc(int x)
{
tot=1; dis[1]=0; used[0]=1;
for (int i=head[x],last=0;~i;i=e[i].next)
{
int v=e[i].to;
if (!vis[v]) dfs(v,x,e[i].dis);
for (int i=last+1;i<=tot;i++)
for (int j=1;j<=m;j++)
if (!ask[j].ans && ask[j].k>=dis[i])
ask[j].ans=used[ask[j].k-dis[i]];
for (int i=last+1;i<=tot;i++)
if (dis[i]<K) used[dis[i]]=1;
last=tot;
}
for (int i=0;i<=tot;i++)
if (dis[i]<K) used[dis[i]]=0;
}
void solve(int x)
{
calc(x);
vis[x]=1;
for (int i=head[x];~i;i=e[i].next)
{
int v=e[i].to;
if (!vis[v])
{
rt=0; maxp[0]=Inf; sum=size[v];
findrt(v,x);
solve(rt);
}
}
}
int main()
{
memset(head,-1,sizeof(head));
scanf("%d%d",&n,&m);
for (int i=1,x,y,z;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z); add(y,x,z);
}
for (int i=1;i<=m;i++)
scanf("%d",&ask[i].k);
maxp[0]=Inf; sum=n;
findrt(1,0);
solve(rt);
for (int i=1;i<=m;i++)
if (ask[i].ans) printf("AYE\n");
else printf("NAY\n");
return 0;
}