题目背景
感谢hzwer的点分治互测。
题目描述
给定一棵有n个点的树
询问树上距离为k的点对是否存在。
输入输出格式
输入格式:
n,m 接下来n-1条边a,b,c描述a到b有一条长度为c的路径
接下来m行每行询问一个K
输出格式:
对于每个K每行输出一个答案,存在输出“AYE”,否则输出”NAY”(不包含引号)
输入输出样例
输入样例#1:
2 1 1 2 2 2
输出样例#1:
AYE
说明
对于30%的数据n<=100
对于60%的数据n<=1000,m<=50
对于100%的数据n<=10000,m<=100,c<=1000,K<=10000000
套用点分治的模板。
设num[k]为树中长度为k的路径的出现次数。
对于一个点no,可以dfs遍历一遍以它为根的子树中、以它开头向下的链长度。
对于每一对以no开头的链长度d[i],d[j],num[d[i]+d[j]]++即可。
当然要记得容斥,因为要去掉d[i],d[j]表示的链来自no的同一个儿子。
设no与一个儿子相连边的长度为dis,在递归处理这个儿子的时候,对儿子的每一对d[i],d[j],做num[d[i]+d[j]+dis*2]--即可。
1 #include<cstdio> 2 #include<cstring> 3 #include<iostream> 4 #define rint register int 5 using namespace std; 6 7 int read(){ 8 char ch; 9 int re=0; 10 bool flag=0; 11 while((ch=getchar())!='-'&&(ch<'0'||ch>'9')); 12 ch=='-'?flag=1:re=ch-'0'; 13 while((ch=getchar())>='0'&&ch<='9') re=re*10+ch-'0'; 14 return flag?-re:re; 15 } 16 17 struct Edge{ 18 int to,nxt,w; 19 Edge(int to=0,int nxt=0,int w=0): 20 to(to),nxt(nxt),w(w){} 21 }; 22 23 const int maxn=10005; 24 25 int n,m,cnt=0,sum,tot,root; 26 int head[maxn],son[maxn],F[maxn],num[10000005],d[maxn]; 27 bool vis[maxn]; 28 Edge E[maxn<<1]; 29 30 inline void a_ed(int from,int to,int w){ 31 E[++cnt]=Edge(to,head[from],w); 32 head[from]=cnt; 33 E[++cnt]=Edge(from,head[to],w); 34 head[to]=cnt; 35 } 36 37 void init(){ 38 n=read(); m=read(); 39 for(rint i=1,from,to,w;i<n;i++){ 40 from=read(); to=read(); w=read(); 41 a_ed(from,to,w); 42 } 43 } 44 45 void getroot(int no,int fa){ 46 son[no]=1; F[no]=0; 47 for(rint e=head[no];e;e=E[e].nxt){ 48 int nt=E[e].to; 49 if(nt==fa||vis[nt]) continue; 50 getroot(nt,no); 51 son[no]+=son[nt]; 52 F[no]=max(F[no],son[nt]); 53 } 54 F[no]=max(F[no],sum-son[no]); 55 if(F[no]<F[root]) root=no; 56 } 57 58 void getdeep(int no,int fa,int dd){ 59 d[tot++]=dd; 60 for(rint e=head[no];e;e=E[e].nxt){ 61 int nt=E[e].to; 62 if(nt==fa||vis[nt]) continue; 63 getdeep(nt,no,dd+E[e].w); 64 } 65 } 66 67 void calc(int no,bool opt,int p){ 68 tot=0; 69 getdeep(no,0,0); 70 for(int i=0;i<tot;i++) 71 for(int j=0;j<tot;j++) 72 if(opt) num[d[i]+d[j]]++; 73 else num[d[i]+d[j]+p]--; 74 } 75 76 void solve(int no){ 77 vis[no]=1; 78 calc(no,1,0); 79 for(rint e=head[no];e;e=E[e].nxt){ 80 int nt=E[e].to; 81 if(vis[nt]) continue; 82 calc(nt,0,E[e].w<<1); 83 sum=son[nt]; root=0; 84 getroot(nt,0); 85 solve(nt); 86 } 87 } 88 89 int main(){ 90 init(); 91 sum=F[root=0]=n; 92 getroot(1,0); 93 solve(root); 94 for(rint i=0,q;i<m;i++){ 95 q=read(); 96 if(num[q]) puts("AYE"); 97 else puts("NAY"); 98 } 99 return 0; 100 }