KD-tree 讲解
by simb351
应神犇junble19768的要求,来水一发KD-tree讲解。学习过程中发现关于KD-tree的资源实在是少,在OI中的应用更是少之又少。
虽然这东西很水,随便嘴一嘴就能口胡出来。所以写一篇讲解。
前置芝士:
1. 基础BST姿势
2. 替罪羊树
3. 优良的空间感 比如能脑补出k维空间QWQ
KD-tree的应用:
KD-tree主要解决对K维数据的管理,比如多维偏序。但是本弱发现目前OI中KD-tree主流考法为维护二维平面中区间的信息。比如
ION 9102 弹跳 但是好像能被菜鸡simba口胡的暴力算几卡过。
KD-tree的原理:
考虑从一般BST中类比过来---对于其中一个节点其左边节点的值恒小于它本身,右边反之。实际上是把所有节点的值从中间分开。
如果说BST是对一个一维线段的分割,那么KD-tree就是对K维空间分割。最终在小的空间内统计答案。说人话就是对K维按顺序均匀分割
查哪部分就去哪个块中查找。因为是按序分割,所以找到答案空间的时间是nlogn至nsqrtn我信你个鬼,不带O2天天被卡。为了划分空间
,KD-tree在第i层维护第i%k维的信息,即这一维中比它小的在左子树,大的在右子树。对于查询就像BST一样就好了。同BST,考虑
KD-tree如何保持自身平衡。由于用方差过于优雅,此处选择替罪羊树一样的思路---拍扁重建。这样KD-tree就愉快的讲完了,撒花。
KD-tree代码实现:
首先是树的结点。
struct point { int x[DIM]; //DIM☞维度 x表示一个k维向量 bool operator < (const point X) const { return x[now]<X.x[now]; } /* 考虑分割一个维度时,为了让分割更均匀,要尽量选最中间的点 now表示当前维护维度,定义小于号来维护中间的点。*/ }// 存储一个向量 struct node { int l,r;//左右子树 int sze;//子树大小 int minn[DIM];//此节点维护的空间中第i维的最小值 int maxx[DIM];//此节点维护的空间中第i维的最大值 point data;//这个点所维护的向量 }//KD-tree上一个节点的定义
然后是维护一个节点的信息。
void update(int pos) { for(int i=0;i<DEM;i++) { tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i]; if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]); if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]); if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]); if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]); } tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1; }/*就是字面意思,维护子树大小,维护其子树中能到达某一维度的最大值,最小值。*/
然后是把一个子树拍扁。
void update(int pos) { for(int i=0;i<DEM;i++) { tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i]; if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]); if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]); if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]); if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]); } tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1; }/*就是字面意思,维护子树大小,维护其子树中能到达某一维度的最大值,最小值。*/
接着是将一个序列加到树上,就是把树拍扁后再挂到树上。
void update(int pos) { for(int i=0;i<DEM;i++) { tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i]; if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]); if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]); if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]); if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]); } tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1; }/*就是字面意思,维护子树大小,维护其子树中能到达某一维度的最大值,最小值。*/
查看这颗树要不要拍扁重建。
void check(int& pos,int dim) { if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze) {rev(pos,0); pos=build(1,tree[pos].sze,dim);} }// 字面意思 不平衡就拍扁重建 , alpha是替罪羊的平衡因子
插入单个点。
int insert(int pos,point data,int dim) { if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;} if(data.x[dim]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dim^1); else tree[pos].r=insert(tree[pos].r,data,dim^1); update(pos); check(pos,dim); return pos; }//像平衡树一样左右看看加那边,然后挂上节点。最后check一下不让树退化
查询,就查询经典问题,给定n个点坐标,以及一个点坐标s 询问n个点中那个离s最近的点是哪个。
void query(int pos,point data) { ans=min(ans,dist(data,tree[pos].data)); //用当前点信息更新ans int dist_left=INF; int dist_right=INF; if(tree[pos].l) dist_left=get_dist(data,tree[pos].l); if(tree[pos].r) dist_right=get_dist(data,tree[pos].r); // L,R维护是查询点s到当前点左右子树所维护空间的距离 if(dist_left<dist_right) { if(dist_left<ans) query(tree[pos].l,data); if(dist_right<ans) query(tree[pos].r,data); } else { if(dist_right<ans) query(tree[pos].r,data); if(dist_left<ans) query(tree[pos].l,data); } // 以当前点为圆心,以ans为半径画圆,如果达不到左/右子树所维护的空间,就不查那边。 }
没了,真没了,写写题就好了。
附上bzoj2648代码,就是上面的问题
#include<bits/stdc++.h> #define DEM 2 #define alpha (1.130/2) #define maxn 1000010 #define INF 0x3f3f3f3f using namespace std; int n,m; int u,v; int now; int ans; int opt; int root; int points; queue<int>Q; struct point { int x[DEM]; bool operator < (const point X) const {return x[now]<X.x[now];} }one[maxn]; struct node { int l,r; int sze; int minn[DEM]; int maxx[DEM]; point data; }tree[maxn]; int dist(point,point); int get_dist(point,int); int New(); void update(int); void rev(int,int); int build(int,int,int); void check(int&,int); int insert(int,point,int); void query(int,point); int main() { cin>>n>>m; for(int i=1;i<=n;i++) cin>>one[i].x[0]>>one[i].x[1]; root=build(1,n,0); for(int i=1;i<=m;i++) { cin>>opt>>u>>v; if(opt==1) {root=insert(root,(point){u,v},0);} else {ans=INF; query(root,(point){u,v}); cout<<ans<<endl;} } } int dist(point A,point B) {return abs(A.x[0]-B.x[0])+abs(A.x[1]-B.x[1]);} int get_dist(point A,int pos) {int ret=0; for(int i=0;i<DEM;i++) ret+=max(0,A.x[i]-tree[pos].maxx[i])+max(0,tree[pos].minn[i]-A.x[i]); return ret;} int New() { if(!Q.empty()) {static int tmp; tmp=Q.front(); Q.pop(); return tmp;} else return ++points; } void update(int pos) { for(int i=0;i<DEM;i++) { tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i]; if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]); if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]); if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]); if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]); } tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1; } void rev(int pos,int num) { if(tree[pos].l) rev(tree[pos].l,num); one[tree[tree[pos].l].sze+num+1]=tree[pos].data; Q.push(pos); if(tree[pos].r) rev(tree[pos].r,tree[tree[pos].l].sze+num+1); } int build(int l,int r,int dem) { if(l>r) return 0; int mid=(l+r)>>1,pos=New(); now=dem; nth_element(one+l,one+mid,one+r+1); tree[pos].data=one[mid]; tree[pos].l=build(l,mid-1,dem^1); tree[pos].r=build(mid+1,r,dem^1); update(pos); return pos; } void check(int& pos,int dem) { if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze) {rev(pos,0); pos=build(1,tree[pos].sze,dem);} } int insert(int pos,point data,int dem) { if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;} if(data.x[dem]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dem^1); else tree[pos].r=insert(tree[pos].r,data,dem^1); update(pos); check(pos,dem); return pos; } void query(int pos,point data) { ans=min(ans,dist(data,tree[pos].data)); int dist_left=INF; int dist_right=INF; if(tree[pos].l) dist_left=get_dist(data,tree[pos].l); if(tree[pos].r) dist_right=get_dist(data,tree[pos].r); if(dist_left<dist_right) { if(dist_left<ans) query(tree[pos].l,data); if(dist_right<ans) query(tree[pos].r,data); } else { if(dist_right<ans) query(tree[pos].r,data); if(dist_left<ans) query(tree[pos].l,data); } }