题目链接:hdu_5314_Happy King
题意:
给出一颗n个结点的树,点上有权值;
求点对(x,y)满足x!=y且x到y的路径上最大值与最小值的差<=D;
题解:
还是树的点分治,在统计答案的时候先按到根的最小值排序,然后用最大值减D去找有多少个满足答案。
1 #include<bits/stdc++.h> 2 #define F(i,a,b) for(int i=a;i<=b;++i) 3 using namespace std; 4 typedef pair<int,int>P; 5 typedef long long ll; 6 const int N=1e5+7; 7 8 int n,k,g[N],v[N*2],nxt[N*2],ed,w[N],t; 9 int vis[N],size[N],mx[N],mi,tot,root; 10 P dis[N]; 11 ll ret; 12 13 inline void adg(int x,int y){v[++ed]=y,nxt[ed]=g[x],g[x]=ed;} 14 void init(){F(i,1,n)g[i]=0,vis[i]=0;ed=0,ret=0;} 15 16 void dfs_size(int u,int fa) 17 { 18 size[u]=1,mx[u]=0; 19 for(int i=g[u];i;i=nxt[i]) 20 if(v[i]!=fa&&!vis[v[i]]) 21 { 22 dfs_size(v[i],u),size[u]+=size[v[i]]; 23 if(size[v[i]]>mx[u])mx[u]=size[v[i]]; 24 } 25 } 26 27 void dfs_root(int r,int u,int fa) 28 { 29 if(size[r]-size[u]>mx[u])mx[u]=size[r]-size[u]; 30 if(mx[u]<mi)mi=mx[u],root=u; 31 for(int i=g[u];i;i=nxt[i]) 32 if(v[i]!=fa&&!vis[v[i]]) 33 dfs_root(r,v[i],u); 34 } 35 36 void dfs_dis(int u,int mi,int mx,int fa) 37 { 38 mi=min(mi,w[u]),mx=max(mx,w[u]); 39 if(mx<=mi+k)dis[++tot]=P(mi,mx); 40 for(int i=g[u];i;i=nxt[i]) 41 if(v[i]!=fa&&!vis[v[i]]) 42 dfs_dis(v[i],mi,mx,u); 43 } 44 45 ll calc(int u,int mi,int mx) 46 { 47 ll ans=0; 48 tot=0,dfs_dis(u,mi,mx,0); 49 sort(dis+1,dis+tot+1); 50 F(i,1,tot) 51 { 52 int p=lower_bound(dis+1,dis+i+1,P(dis[i].second-k,0))-dis; 53 ans+=i-p; 54 } 55 return ans; 56 } 57 58 void dfs(int u=1) 59 { 60 mi=n,dfs_size(u,0); 61 dfs_root(u,u,0); 62 ret+=calc(root,w[root],w[root]),vis[root]=1; 63 for(int i=g[root];i;i=nxt[i]) 64 if(!vis[v[i]])ret-=calc(v[i],w[root],w[root]); 65 for(int i=g[root];i;i=nxt[i]) 66 if(!vis[v[i]])dfs(v[i]); 67 } 68 69 int main() 70 { 71 scanf("%d",&t); 72 while(t--) 73 { 74 scanf("%d%d",&n,&k); 75 init(); 76 F(i,1,n)scanf("%d",w+i); 77 F(i,1,n-1) 78 { 79 int x,y; 80 scanf("%d%d",&x,&y); 81 adg(x,y),adg(y,x); 82 } 83 dfs(),printf("%lld ",ret*2); 84 } 85 return 0; 86 }
法2:
将统计的答案按倒根的最大值排序,如果当前最大值为这个点的最大值,那么我们就在树状数组中去找最大值-D的答案,所以我们在统计好后需要将这个点的最小值插入树状数组
1 #include<bits/stdc++.h> 2 #define F(i,a,b) for(int i=a;i<=b;++i) 3 using namespace std; 4 typedef pair<int,int>P; 5 typedef long long ll; 6 const int N=1e5+7; 7 8 int n,k,g[N],v[N*2],nxt[N*2],ed,w[N],t,hsh[N],hsh_ed; 9 int vis[N],size[N],mx[N],mi,tot,root,sum[N]; 10 P dis[N];ll ans; 11 12 inline void adg(int x,int y){v[++ed]=y,nxt[ed]=g[x],g[x]=ed;} 13 void init(){F(i,1,n)g[i]=0,vis[i]=0;ed=ans=0,hsh_ed=0;} 14 15 inline void add(int x,int c){while(x<=hsh_ed+1)sum[x]+=c,x+=x&-x;} 16 inline int ask(int x){int an=0;while(x>0)an+=sum[x],x-=x&-x;return an;} 17 inline int getid(int x){return lower_bound(hsh+1,hsh+1+hsh_ed,x)-hsh;} 18 void dfs_size(int u,int fa) 19 { 20 size[u]=1,mx[u]=0; 21 for(int i=g[u];i;i=nxt[i]) 22 if(v[i]!=fa&&!vis[v[i]]) 23 { 24 dfs_size(v[i],u),size[u]+=size[v[i]]; 25 if(size[v[i]]>mx[u])mx[u]=size[v[i]]; 26 } 27 } 28 29 void dfs_root(int r,int u,int fa) 30 { 31 if(size[r]-size[u]>mx[u])mx[u]=size[r]-size[u]; 32 if(mx[u]<mi)mi=mx[u],root=u; 33 for(int i=g[u];i;i=nxt[i]) 34 if(v[i]!=fa&&!vis[v[i]]) 35 dfs_root(r,v[i],u); 36 } 37 38 void dfs_dis(int u,int mi,int mx,int fa) 39 { 40 mi=min(mi,w[u]),mx=max(mx,w[u]); 41 if(mx<=mi+k)dis[++tot]=P(mx,mi); 42 for(int i=g[u];i;i=nxt[i]) 43 if(v[i]!=fa&&!vis[v[i]]) 44 dfs_dis(v[i],mi,mx,u); 45 } 46 47 ll calc(int u,int mi,int mx) 48 { 49 ll ans=0; 50 tot=0,dfs_dis(u,mi,mx,0); 51 sort(dis+1,dis+1+tot); 52 F(i,1,tot) 53 { 54 ans+=ask(hsh_ed)-ask(getid(dis[i].first-k)-1); 55 add(getid(dis[i].second),1); 56 } 57 F(i,1,tot)add(getid(dis[i].second),-1); 58 return ans; 59 } 60 61 void dfs(int u=1) 62 { 63 mi=n,dfs_size(u,0); 64 dfs_root(u,u,0); 65 ans+=calc(root,w[root],w[root]),vis[root]=1; 66 for(int i=g[root];i;i=nxt[i]) 67 if(!vis[v[i]]) 68 ans-=calc(v[i],w[root],w[root]); 69 for(int i=g[root];i;i=nxt[i]) 70 if(!vis[v[i]])dfs(v[i]); 71 } 72 73 int main() 74 { 75 scanf("%d",&t); 76 while(t--) 77 { 78 scanf("%d%d",&n,&k); 79 init(); 80 F(i,1,n)scanf("%d",w+i),hsh[i]=w[i]; 81 F(i,1,n-1) 82 { 83 int x,y; 84 scanf("%d%d",&x,&y); 85 adg(x,y),adg(y,x); 86 } 87 sort(hsh+1,hsh+1+n),hsh_ed=unique(hsh+1,hsh+1+n)-hsh; 88 dfs(),printf("%lld ",ans*2); 89 } 90 return 0; 91 }