QwQ点分治这个东西 还是很有意思的
点分治主要是用来解决一些树上路径问题
首先,我们要明确点分治的分治标准是重心
什么是重心?
如果以(x)为根,所有子树的最大(size)最小,那么称(x)是这棵树的重心
那么怎么去找重心呢
我们直接(dfs),然后对于每个点,求一个(mx[i])表示,以(i)为根的子树的最大的(siz)
int getroot(int x,int fa)
{
siz[x]=1;
mx[x]=0;
for (int i=point[x];i;i=nxt[i])
{
int p = to[i];
if (vis[p] || p==fa) continue;
getroot(p,x);
siz[x]+=siz[p];
mx[x]=max(mx[x],siz[p]);
}
mx[x]=max(mx[x],n-siz[x]);
if (mx[x]<mx[root]) root=x;
}
下面,我们来介绍点分治的过程
每次实际上就是重复这样的一个过程
每次找到当前子树的重心,然后求过这个重心的路径的贡献,然后容斥一下, 减去会被重复计算的贡献,然后再分别递归当前重心节点的所有子树,重复这个过程
于是每一次找到重心,递归的子树大小是不超过原树大小的一半的,那么递归层数不会超过(O(logn))层,时间复杂度为(O(nlogn))
这个时间复杂度的分析,我是用调和剂数来做的,你考虑对于这个总的循环次数 应该(n+frac{n}{2}*2 + frac{n}{4}*4cdots),如果把n提出来,后面自然就是个(logn),所以总复杂度是(O(nlog n))
那么回到这个题目
我们可以统计出每个长度的路径条数,然后针对询问(O(1))回答。
很显然的是,我们可以对每个重心开始(dfs),然后两重循环枚举点,将(sum[dis[i]+dis[j]]++),可以用上面同样的方法证明这个复杂度是(O(n^2 logn))的
那么这里就会出现不合法的路径,也就是一条边经过两个的那种,我们只需要把其他点的(dis+len[i]),这样再枚举点的时候,就强制默认了重复走了那条边,然后把(sum[dis[i]+dis[j]--)就行
一些细节还是直接看代码吧
不过有要注意的地方就是,要把已经计算过的重心打上标记,这样不会在(dfs)的时候,走到其他子树里
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn = 1e5+1e2;
const int maxm = maxn*3;
int point[maxn],nxt[maxm],to[maxm];
int dis[maxn],son[maxn],siz[maxn];
int root,n,m,cnt;
int val[maxn];
int num;
int mx[maxn];
int vis[maxn];
int sum[10000000];
void addedge(int x,int y,int w)
{
nxt[++cnt]=point[x];
to[cnt]=y;
val[cnt]=w;
point[x]=cnt;
}
void insert(int x,int y,int w)
{
addedge(x,y,w);
addedge(y,x,w);
}
int getroot(int x,int fa)
{
siz[x]=1;
mx[x]=0;
for (int i=point[x];i;i=nxt[i])
{
int p = to[i];
if (vis[p] || p==fa) continue;
getroot(p,x);
siz[x]+=siz[p];
mx[x]=max(mx[x],siz[p]);
}
mx[x]=max(mx[x],n-siz[x]);
if (mx[x]<mx[root]) root=x;
}
void getdis(int x,int fa,int len)
{
dis[++num]=len;
for (int i=point[x];i;i=nxt[i])
{
int p = to[i];
if (p==fa || vis[p]) continue;
getdis(p,x,len+val[i]);
}
}
void solve(int x,int len)
{
num=0;
getdis(x,0,len);
if (len!=0)
{
for (int i=1;i<=num;i++)
for (int j=i+1;j<=num;j++)
sum[dis[i]+dis[j]]--;
}
else
{
for (int i=1;i<=num;i++)
for (int j=i+1;j<=num;j++)
sum[dis[i]+dis[j]]++;
}
}
void dfs(int x)
{
vis[x]=1;
solve(x,0);
for (int i=point[x];i;i=nxt[i]){
int p = to[i];
if (vis[p]) continue;
solve(p,val[i]);
root=0;
n=siz[p];root=0;
getroot(p,0);
//cout<<root<<" "<<sum[2]<<endl;
dfs(root);
}
}
int main()
{
n=read(),m=read();
for (int i=1;i<n;i++) {
int x,y,w;
x=read(),y=read(),w=read();
insert(x,y,w);
}
mx[root]=2e9;
getroot(1,0);
//cout<<root;
dfs(root);
for (int i=1;i<=m;i++)
{
int x = read();
if (sum[x]) cout<<"AYE"<<endl;
else cout<<"NAY"<<endl;
}
return 0;
}