zoukankan      html  css  js  c++  java
  • KD-tree讲解

      KD-tree  讲解
     by simb351 

      应神犇junble19768的要求,来水一发KD-tree讲解。学习过程中发现关于KD-tree的资源实在是少,在OI中的应用更是少之又少。
    虽然这东西很水,随便嘴一嘴就能口胡出来。所以写一篇讲解。

     前置芝士:
    1. 基础BST姿势
    2. 替罪羊树
    3. 优良的空间感 比如能脑补出k维空间QWQ

    KD-tree的应用:
      KD-tree主要解决对K维数据的管理,比如多维偏序。但是本弱发现目前OI中KD-tree主流考法为维护二维平面中区间的信息。比如
    ION 9102 弹跳 但是好像能被菜鸡simba口胡的暴力算几卡过。


    KD-tree的原理:
      考虑从一般BST中类比过来---对于其中一个节点其左边节点的值恒小于它本身,右边反之。实际上是把所有节点的值从中间分开。
    如果说BST是对一个一维线段的分割,那么KD-tree就是对K维空间分割。最终在小的空间内统计答案。说人话就是对K维按顺序均匀分割
    查哪部分就去哪个块中查找。因为是按序分割,所以找到答案空间的时间是nlogn至nsqrtn我信你个鬼,不带O2天天被卡。为了划分空间
    ,KD-tree在第i层维护第i%k维的信息,即这一维中比它小的在左子树,大的在右子树。对于查询就像BST一样就好了。同BST,考虑
    KD-tree如何保持自身平衡。由于用方差过于优雅,此处选择替罪羊树一样的思路---拍扁重建。这样KD-tree就愉快的讲完了,撒花。
    KD-tree代码实现:
    首先是树的结点。
      
     
    struct point
    {
        int x[DIM]; //DIM☞维度 x表示一个k维向量 
        bool operator < (const point X) const
        {
            return x[now]<X.x[now];
        }
        /* 考虑分割一个维度时,为了让分割更均匀,要尽量选最中间的点 now表示当前维护维度,定义小于号来维护中间的点。*/
    }// 存储一个向量 
    struct node
    {
        int l,r;//左右子树
        int sze;//子树大小
        int minn[DIM];//此节点维护的空间中第i维的最小值
        int maxx[DIM];//此节点维护的空间中第i维的最大值
        point data;//这个点所维护的向量
    }//KD-tree上一个节点的定义

      
     
      
    然后是维护一个节点的信息。
     
    void update(int pos)
    {
        for(int i=0;i<DEM;i++)
        {
            tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
            if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
            if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
            if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
            if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
        }
        tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
    }/*就是字面意思,维护子树大小,维护其子树中能到达某一维度的最大值,最小值。*/

     
    然后是把一个子树拍扁。
    void update(int pos)
    {
        for(int i=0;i<DEM;i++)
        {
            tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
            if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
            if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
            if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
            if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
        }
        tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
    }/*就是字面意思,维护子树大小,维护其子树中能到达某一维度的最大值,最小值。*/
    接着是将一个序列加到树上,就是把树拍扁后再挂到树上。
     
    void update(int pos)
    {
        for(int i=0;i<DEM;i++)
        {
            tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
            if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
            if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
            if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
            if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
        }
        tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
    }/*就是字面意思,维护子树大小,维护其子树中能到达某一维度的最大值,最小值。*/
    查看这颗树要不要拍扁重建。
    void check(int& pos,int dim)
    {
        if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze)
        {rev(pos,0); pos=build(1,tree[pos].sze,dim);}
    }// 字面意思 不平衡就拍扁重建 , alpha是替罪羊的平衡因子

    插入单个点。
     
    int insert(int pos,point data,int dim)
    {
        if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;}
        if(data.x[dim]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dim^1);
        else tree[pos].r=insert(tree[pos].r,data,dim^1);
        update(pos); check(pos,dim); return pos;
    }//像平衡树一样左右看看加那边,然后挂上节点。最后check一下不让树退化
    查询,就查询经典问题,给定n个点坐标,以及一个点坐标s 询问n个点中那个离s最近的点是哪个。
     
    void query(int pos,point data)
    {
        ans=min(ans,dist(data,tree[pos].data)); //用当前点信息更新ans
        int dist_left=INF; int dist_right=INF;
        if(tree[pos].l) dist_left=get_dist(data,tree[pos].l);
        if(tree[pos].r) dist_right=get_dist(data,tree[pos].r);
        // L,R维护是查询点s到当前点左右子树所维护空间的距离
        if(dist_left<dist_right)
        {
            if(dist_left<ans) query(tree[pos].l,data);
            if(dist_right<ans) query(tree[pos].r,data);
        }
        else 
        {
            if(dist_right<ans) query(tree[pos].r,data);
            if(dist_left<ans) query(tree[pos].l,data);
        }
        // 以当前点为圆心,以ans为半径画圆,如果达不到左/右子树所维护的空间,就不查那边。
    }

    没了,真没了,写写题就好了。
    附上bzoj2648代码,就是上面的问题
     
    #include<bits/stdc++.h>
    #define DEM 2
    #define alpha (1.130/2)
    #define maxn 1000010
    #define INF 0x3f3f3f3f
    using namespace std;
    int n,m;
    int u,v;
    int now;
    int ans;
    int opt;
    int root;
    int points;
    queue<int>Q;
    struct point
    {
        int x[DEM]; 
        bool operator < (const point X) const {return x[now]<X.x[now];}
    }one[maxn];
    struct node
    {
        int l,r;
        int sze;
        int minn[DEM];
        int maxx[DEM];
        point data;
    }tree[maxn];
    int dist(point,point);
    int get_dist(point,int);
    int New();
    void update(int);
    void rev(int,int);
    int build(int,int,int);
    void check(int&,int);
    int insert(int,point,int);
    void query(int,point);
    int main()
    {
        cin>>n>>m;
        for(int i=1;i<=n;i++) cin>>one[i].x[0]>>one[i].x[1];
        root=build(1,n,0);
        for(int i=1;i<=m;i++)
        {
            cin>>opt>>u>>v;
            if(opt==1) {root=insert(root,(point){u,v},0);}
            else {ans=INF; query(root,(point){u,v}); cout<<ans<<endl;}  
        } 
    }
    int dist(point A,point B) {return abs(A.x[0]-B.x[0])+abs(A.x[1]-B.x[1]);}
    int get_dist(point A,int pos) {int ret=0; for(int i=0;i<DEM;i++) ret+=max(0,A.x[i]-tree[pos].maxx[i])+max(0,tree[pos].minn[i]-A.x[i]); return ret;}
    int New()
    {
        if(!Q.empty()) {static int tmp; tmp=Q.front(); Q.pop(); return tmp;}
        else return ++points;
    }
    void update(int pos)
    {
        for(int i=0;i<DEM;i++)
        {
            tree[pos].maxx[i]=tree[pos].minn[i]=tree[pos].data.x[i];
            if(tree[pos].l) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].l].minn[i]);
            if(tree[pos].r) tree[pos].minn[i]=min(tree[pos].minn[i],tree[tree[pos].r].minn[i]);
            if(tree[pos].l) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].l].maxx[i]);
            if(tree[pos].r) tree[pos].maxx[i]=max(tree[pos].maxx[i],tree[tree[pos].r].maxx[i]);
        }
        tree[pos].sze=tree[tree[pos].l].sze+tree[tree[pos].r].sze+1;
    }
    void rev(int pos,int num)
    {
        if(tree[pos].l) rev(tree[pos].l,num);
        one[tree[tree[pos].l].sze+num+1]=tree[pos].data; Q.push(pos);
        if(tree[pos].r) rev(tree[pos].r,tree[tree[pos].l].sze+num+1);
    }
    int build(int l,int r,int dem)
    {
        if(l>r) return 0;
        int mid=(l+r)>>1,pos=New();
        now=dem; nth_element(one+l,one+mid,one+r+1); tree[pos].data=one[mid]; 
        tree[pos].l=build(l,mid-1,dem^1); tree[pos].r=build(mid+1,r,dem^1);
        update(pos); return pos;
    }
    void check(int& pos,int dem)
    {
        if(tree[pos].sze*alpha<tree[tree[pos].l].sze||tree[pos].sze*alpha<tree[tree[pos].r].sze)
        {rev(pos,0); pos=build(1,tree[pos].sze,dem);}
    }
    int insert(int pos,point data,int dem)
    {
        if(!pos) {pos=New(); tree[pos].l=tree[pos].r=0; tree[pos].data=data; update(pos); return pos;}
        if(data.x[dem]<=tree[pos].data.x[dem]) tree[pos].l=insert(tree[pos].l,data,dem^1);
        else tree[pos].r=insert(tree[pos].r,data,dem^1);
        update(pos); check(pos,dem); return pos;
    }
    void query(int pos,point data)
    {
        ans=min(ans,dist(data,tree[pos].data));
        int dist_left=INF; int dist_right=INF;
        if(tree[pos].l) dist_left=get_dist(data,tree[pos].l);
        if(tree[pos].r) dist_right=get_dist(data,tree[pos].r);
        if(dist_left<dist_right)
        {
            if(dist_left<ans) query(tree[pos].l,data);
            if(dist_right<ans) query(tree[pos].r,data);
        }
        else 
        {
            if(dist_right<ans) query(tree[pos].r,data);
            if(dist_left<ans) query(tree[pos].l,data);
        }
    }

     
     
  • 相关阅读:
    Python requests“Max retries exceeded with url” error
    命令行链接mongo、redis、mysql
    python 删除字典某个key(键)及对应值
    python标准模块(二)
    python标准模块(一)
    格式化输出
    LeetCode----1. Two Sum
    文件操作(初阶)
    python函数基础
    python3内置函数
  • 原文地址:https://www.cnblogs.com/simba351/p/11504169.html
Copyright © 2011-2022 走看看