zoukankan      html  css  js  c++  java
  • KD-Tree总结

    KD-Tree总结

    问题引入

    平面上有(n)个点,(q)组询问,每一次查询距离((x,y))最近的点对,强制在线。

    问题解决

    暴力

    显然我们可以直接枚举点然后算距离取(min),这样子复杂度是(Theta(nq))的。

    KD-Tree

    (KD-Tree)就是一个解决这种问题的利器

    我们不妨从这个平面中选出一些点把平面分割成两个部分,那么所有的点就会在一段范围内对吧。我们只需要暴力的找每一段里面的即可。

    但是这样子复杂度还是不对,还是(Theta(nq))的,此时我们需要把每一块区域的范围给算出来,判断边界到这个点的距离是不是比(ans)小,即搜索的过程中进行乐观估价剪枝。

    当然还可以加一些搜索顺序的优化,具体实现参见代码。

    至此,(KD-Tree)的大致流程就讲完了,虽然我觉得我自己都看着一脸懵逼,但是结合代码你可能可以获得更好的阅读体验。

    更进一步

    如果这个时候可以插入点或者删除点呢?

    插入

    直接暴力插进去然后按照替罪羊树那套理论重构即可。

    删除

    打个惰性删除标记然后替罪羊树那套理论重构即可。

    例题

    P2479 [SDOI2010]捉迷藏

    直接查即可,注意不能和当前点重复。

    #include<stdio.h>
    #include<stdlib.h>
    #include<string.h>
    #include<math.h>
    #include<algorithm>
    #include<queue>
    #include<set>
    #include<map>
    #include<iostream>
    using namespace std;
    #define ll long long
    #define REP(a,b,c) for(int a=b;a<=c;a++)
    #define re register
    #define file(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout)
    inline int gi(){
    	int f=1,sum=0;char ch=getchar();
    	while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    	while(ch>='0' && ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
    	return f*sum;
    }
    const int N=500010,Inf=1e9+10;
    int rt,n,tot,now,ans1,ans2;
    struct node{int x[2];bool operator!=(const node &a)const{return (x[0]!=a.x[0]) || (x[1]!=a.x[1]);}}a[N];
    struct tree{int ls,rs,mn[2],mx[2];node w;}t[N];
    int newnode(){return ++tot;}
    bool cmp(node a,node b){return a.x[now]<b.x[now];}
    void update(int o){
    	for(int i=0;i<2;i++){
    		t[o].mx[i]=t[o].mn[i]=t[o].w.x[i];
    		if(t[o].ls)t[o].mx[i]=max(t[o].mx[i],t[t[o].ls].mx[i]),t[o].mn[i]=min(t[o].mn[i],t[t[o].ls].mn[i]);
    		if(t[o].rs)t[o].mx[i]=max(t[o].mx[i],t[t[o].rs].mx[i]),t[o].mn[i]=min(t[o].mn[i],t[t[o].rs].mn[i]);
    	}
    }
    int build(int l,int r,int opt){
    	if(l>r)return 0;
    	int mid=(l+r)>>1,o=newnode();now=opt;
    	nth_element(a+l,a+mid,a+r+1,cmp);t[o].w=a[mid];
    	t[o].ls=build(l,mid-1,opt^1);t[o].rs=build(mid+1,r,opt^1);
    	update(o);return o;
    }
    int getmin(int o,node now){
    	int ret=0;
    	for(int i=0;i<2;i++)ret+=max(0,now.x[i]-t[o].mx[i])+max(0,t[o].mn[i]-now.x[i]);
    	return ret;
    }
    int getmax(int o,node now){
    	int ret=0;
    	for(int i=0;i<2;i++)ret+=max(abs(now.x[i]-t[o].mx[i]),abs(now.x[i]-t[o].mn[i]));
    	return ret;
    }
    int dis(node a,node b){return abs(a.x[0]-b.x[0])+abs(a.x[1]-b.x[1]);}
    void query_max(int o,node now){
    	if(t[o].w!=now)ans1=max(ans1,dis(t[o].w,now));
    	int le=-Inf,ri=-Inf;
    	if(t[o].ls)le=getmax(t[o].ls,now);
    	if(t[o].rs)ri=getmax(t[o].rs,now);
    	if(le>ri){
    		if(le>ans1)query_max(t[o].ls,now);
    		if(ri>ans1)query_max(t[o].rs,now);
    	}
    	else{
    		if(ri>ans1)query_max(t[o].rs,now);
    		if(le>ans1)query_max(t[o].ls,now);
    	}
    }
    void query_min(int o,node now){
    	if(t[o].w!=now)ans2=min(ans2,dis(t[o].w,now));
    	int le=Inf,ri=Inf;
    	if(t[o].ls)le=getmin(t[o].ls,now);
    	if(t[o].rs)ri=getmin(t[o].rs,now);
    	if(le<ri){
    		if(le<ans2)query_min(t[o].ls,now);
    		if(ri<ans2)query_min(t[o].rs,now);
    	}
    	else{
    		if(ri<ans2)query_min(t[o].rs,now);
    		if(le<ans2)query_min(t[o].ls,now);
    	}
    }
    int main(){
    	n=gi();
    	for(int i=1;i<=n;i++)a[i].x[0]=gi(),a[i].x[1]=gi();
    	rt=build(1,n,0);int ans=Inf;
    	for(int i=1;i<=n;i++){
    		ans1=-Inf,ans2=Inf;
    		query_max(rt,a[i]);query_min(rt,a[i]);
    		ans=min(ans,ans1-ans2);
    	}
    	printf("%d
    ",ans);
    	return 0;
    }
    

    例题2

    P4169 [Violet]天使玩偶/SJY摆棋子

    插入的模板题,按照上文所述的方法做即可。

    #include<stdio.h>
    #include<stdlib.h>
    #include<string.h>
    #include<math.h>
    #include<algorithm>
    #include<queue>
    #include<set>
    #include<map>
    #include<iostream>
    using namespace std;
    #define ll long long
    #define REP(a,b,c) for(int a=b;a<=c;a++)
    #define re register
    #define file(a) freopen(a".in","r",stdin);freopen(a".out","w",stdout)
    inline int gi(){
    	int f=1,sum=0;char ch=getchar();
    	while(ch>'9' || ch<'0'){if(ch=='-')f=-1;ch=getchar();}
    	while(ch>='0' && ch<='9'){sum=(sum<<3)+(sum<<1)+ch-'0';ch=getchar();}
    	return f*sum;
    }
    const int N=2000010,Inf=1e9+10;
    const double alpha=0.75;
    struct node{int x[2];}a[N];
    struct tree{int mx[2],mn[2],siz,ls,rs;node w;}t[N];
    int n,m,rt,tot,now,ans;
    int sta[N],top;
    int newnode(){if(top)return sta[top--];else return ++tot;}
    bool cmp(node a,node b){return a.x[now]<b.x[now];}
    void update(int o){
    	for(int i=0;i<2;i++){
    		t[o].mx[i]=t[o].mn[i]=t[o].w.x[i];
    		if(t[o].ls)t[o].mx[i]=max(t[o].mx[i],t[t[o].ls].mx[i]),t[o].mn[i]=min(t[o].mn[i],t[t[o].ls].mn[i]);
    		if(t[o].rs)t[o].mx[i]=max(t[o].mx[i],t[t[o].rs].mx[i]),t[o].mn[i]=min(t[o].mn[i],t[t[o].rs].mn[i]);
    	}
    	t[o].siz=t[t[o].ls].siz+t[t[o].rs].siz+1;
    }
    int build(int l,int r,int opt){
    	if(l>r)return 0;
    	int mid=(l+r)>>1,o=newnode();now=opt;
    	nth_element(a+l,a+mid,a+r+1,cmp);t[o].w=a[mid];
    	t[o].ls=build(l,mid-1,opt^1);t[o].rs=build(mid+1,r,opt^1);
    	update(o);return o;
    }
    void get(int o,int cnt){
    	if(t[o].ls)get(t[o].ls,cnt);
    	a[cnt+t[t[o].ls].siz+1]=t[o].w;sta[++top]=o;
    	if(t[o].rs)get(t[o].rs,cnt+t[t[o].ls].siz+1);
    }
    void check(int &o,int opt){
    	if(t[o].siz*alpha<t[t[o].ls].siz || t[o].siz*alpha<t[t[o].rs].siz)
    		get(o,0),o=build(1,t[o].siz,opt);
    }
    void insert(int &o,node now,int opt){
    	if(!o){o=newnode();t[o].w=now;t[o].ls=t[o].rs=0;update(o);return;}
    	if(now.x[opt]<=t[o].w.x[opt])insert(t[o].ls,now,opt^1);
    	else insert(t[o].rs,now,opt^1);
    	update(o);check(o,opt);
    }
    int getdis(node now,int o){
    	int ret=0;
    	for(int i=0;i<2;i++)ret+=max(now.x[i]-t[o].mx[i],0)+max(t[o].mn[i]-now.x[i],0);
    	return ret;
    }
    int dis(node a,node b){return abs(a.x[0]-b.x[0])+abs(a.x[1]-b.x[1]);}
    void query(int o,node now){
    	ans=min(ans,dis(t[o].w,now));
    	int le=Inf,ri=Inf;
    	if(t[o].ls)le=getdis(now,t[o].ls);
    	if(t[o].rs)ri=getdis(now,t[o].rs);
    	if(le<ri){
    		if(le<ans)query(t[o].ls,now);
    		if(ri<ans)query(t[o].rs,now);
    	}
    	else{
    		if(ri<ans)query(t[o].rs,now);
    		if(le<ans)query(t[o].ls,now);
    	}
    }
    int main(){
    	n=gi();m=gi();
    	for(int i=1;i<=n;i++)a[i].x[0]=gi(),a[i].x[1]=gi();
    	rt=build(1,n,0);
    	while(m--){
    		int opt=gi();node now;now.x[0]=gi();now.x[1]=gi();
    		if(opt==1)insert(rt,now,0);
    		else{ans=Inf;query(rt,now);printf("%d
    ",ans);}
    	}
    	return 0;
    }
    

    参考文献

    儿子的Blog
    感谢儿子对我的滋磁!

  • 相关阅读:
    树莓派系统Raspbian安装小结
    树莓派安装centos 7系统
    Ubuntu下安装SSH服务
    使用xUnit为.net core程序进行单元测试(4)
    使用xUnit为.net core程序进行单元测试(3)
    使用xUnit为.net core程序进行单元测试 -- Assert
    使用xUnit为.net core程序进行单元测试(1)
    用 Identity Server 4 (JWKS 端点和 RS256 算法) 来保护 Python web api
    asp.net core 2.0 查缺补漏
    "软件随想录" 读书笔记
  • 原文地址:https://www.cnblogs.com/fexuile/p/11689429.html
Copyright © 2011-2022 走看看