zoukankan      html  css  js  c++  java
  • [基本操作] kd 树

    概念就不说了吧,网上教程满天飞

    学了半天才知道,kd 树实质上只干了两件事情:

    1.快速定位一个点 / 矩形

    2.有理有据地优化暴力

    第一点大概是可以来做二维平面上给点/矩形打标记的问题

    第二点大概是平面最远点对?

    bzoj1941 Hide and Seek

    对每个点求除自己以外的最近点和最远点(曼哈顿距离),输出所有点中(最远距离 - 最近距离)的最小值

    sol:

    kd 树优化暴力。对于暴力,考虑这样一个剪枝:如果一个点在某一维上已经隔得太远,就不再搜比它更远的了

    把这个东西放到 kd 树上就可以了:每个节点维护子树的包围盒,用查询点到包围盒的距离下界(找最近点时)或上界(找最远点时)作为估价,估价不可能更新答案的子树直接跳过;两个儿子按估价从优到劣的顺序搜

    #include<bits/stdc++.h>
    #define LL long long
    using namespace std;
    // Reads one signed decimal integer from stdin using getchar().
    // Non-digit characters before the number are skipped; every '-' among them
    // flips the sign (so "--5" reads as 5). The first non-digit character after
    // the number is consumed and discarded.
    // NOTE(review): if stdin ends before any digit appears this never returns,
    // because isdigit(EOF) is false.
    inline int read()
    {
        int sign = 1;
        char c = getchar();
        while (!isdigit(c))
        {
            if (c == '-') sign = -sign;
            c = getchar();
        }
        int value = 0;
        while (isdigit(c))
        {
            value = 10 * value + c - '0';
            c = getchar();
        }
        return value * sign;
    }
    // Capacity of the point array (points are stored 1-based in ps) and a large
    // sentinel used as +/- infinity for distances.
    const int maxn = 500010,inf = 1e9;
    // Shorthands for the left/right child index of node `x`. They only work
    // inside functions that have a variable named `x` in scope.
    #define L ps[x].lc
    #define R ps[x].rc 
    // n: number of points; dem: splitting dimension read by Node::operator<
    // (set by build() before each nth_element); root: index of the tree root.
    int n,dem,root;
    // One kd-tree node, which is also one input point.
    //   x[2]   : coordinates of the point
    //   lc, rc : child indices into ps (0 = no child)
    //   mx, mn : per-dimension max/min over this node's whole subtree
    struct Node
    {
        int x[2],lc,rc,mx[2],mn[2];
        // Compare by the current splitting dimension `dem` (used by nth_element).
        bool operator < (const Node &b)const{return x[dem] < b.x[dem];}
        // Manhattan distance between two points.
        friend int dis(Node a,Node b){return abs(a.x[0] - b.x[0]) + abs(a.x[1] - b.x[1]);}
    }ps[maxn];
    // Input coordinates of point i, kept separately because build() reorders ps.
    int qx[maxn],qy[maxn];
    void maintain(int x)
    {
        for(int i=0;i<2;i++)
        {
            ps[x].mn[i] = ps[x].mx[i] = ps[x].x[i];
            if(L)ps[x].mx[i] = max(ps[x].mx[i],ps[L].mx[i]);
            if(L)ps[x].mn[i] = min(ps[x].mn[i],ps[L].mn[i]);
            if(R)ps[x].mx[i] = max(ps[x].mx[i],ps[R].mx[i]);
            if(R)ps[x].mn[i] = min(ps[x].mn[i],ps[R].mn[i]);
        }
    }
    // Build a balanced kd-tree over ps[l..r]. The splitting dimension starts at
    // `cur` and alternates with depth. The index of the subtree root is written
    // into x; for an empty range x is set to 0.
    void build(int &x,int l,int r,int cur)
    {
        // Clear the slot explicitly, so a stale child index can never survive
        // if the tree is ever rebuilt in place (matches the rebuildable build()
        // used by the insertion-based solutions).
        if(l > r){x = 0;return;}
        dem = cur;                                // dimension read by Node::operator<
        int mid = (l + r) >> 1;x = mid;
        nth_element(ps + l,ps + mid,ps + r + 1);  // median of ps[l..r] lands at mid
        build(L,l,mid - 1,cur ^ 1);
        build(R,mid + 1,r,cur ^ 1);
        maintain(x);
    }
    // Lower bound on the Manhattan distance from point p to any point inside
    // cur's bounding box: per dimension, how far p lies outside [mn, mx].
    int calmin(Node p,Node cur)
    {
        int gap = 0;
        for(int d=0;d<2;d++)
        {
            int above = p.x[d] - cur.mx[d];   // positive when p is past the max edge
            int below = cur.mn[d] - p.x[d];   // positive when p is before the min edge
            if(above > 0)gap += above;
            if(below > 0)gap += below;
        }
        return gap;
    }
    // Upper bound on the Manhattan distance from point p to any point inside
    // cur's bounding box: per dimension, the distance to the farther box edge.
    int calmax(Node p,Node cur)
    {
        int reach = 0;
        for(int d=0;d<2;d++)
        {
            int toMax = abs(p.x[d] - cur.mx[d]);
            int toMin = abs(p.x[d] - cur.mn[d]);
            reach += toMax > toMin ? toMax : toMin;
        }
        return reach;
    }
    int ans;
    // Branch-and-bound search for the point farthest from cur inside the subtree
    // rooted at x. The global `ans` is raised whenever a farther point is found.
    // Each child's bounding box gives an upper bound (calmax). The child with the
    // larger bound is explored first, and a child is skipped once its bound no
    // longer exceeds `ans`.
    void querymx(Node cur,int x)
    {
        int here = dis(cur,ps[x]);
        if(here > ans)ans = here;
        int lBound = L ? calmax(cur,ps[L]) : -inf;
        int rBound = R ? calmax(cur,ps[R]) : -inf;
        if(rBound < lBound)
        {
            if(lBound > ans)querymx(cur,L);
            if(rBound > ans)querymx(cur,R);
            return;
        }
        if(rBound > ans)querymx(cur,R);
        if(lBound > ans)querymx(cur,L);
    }
    // Branch-and-bound search for the point nearest to cur inside the subtree
    // rooted at x, lowering the global `ans`. A distance of 0 is never taken, so
    // the query point does not match itself; this also skips any other point
    // with identical coordinates. Each child's bounding box gives a lower bound
    // (calmin). The child with the smaller bound is searched first, and a child
    // is skipped when its bound is not below `ans`.
    void querymn(Node cur,int x)
    {
        int here = dis(cur,ps[x]);
        if(here != 0 && here < ans)ans = here;
        int lBound = inf,rBound = inf;
        if(L)lBound = calmin(cur,ps[L]);
        if(R)rBound = calmin(cur,ps[R]);
        int firstChild = R,secondChild = L;
        int firstBound = rBound,secondBound = lBound;
        if(lBound < rBound)
        {
            firstChild = L;secondChild = R;
            firstBound = lBound;secondBound = rBound;
        }
        if(firstBound < ans)querymn(cur,firstChild);
        if(secondBound < ans)querymn(cur,secondChild);
    }
    int query(int x,int y,int type)
    {
        Node cur;
        cur.x[0] = x;cur.x[1] = y;
        if(type == 0)ans = inf,querymn(cur,root);
        else ans = -inf,querymx(cur,root);
        return ans;
    }
    // Reads n points and builds the kd-tree. For every point it computes
    // (farthest other point - nearest other point) in Manhattan distance, and
    // prints the smallest such value.
    int main()
    {
        n = read();
        for(int i=1;i<=n;i++)
        {
            qx[i] = ps[i].x[0] = read();
            qy[i] = ps[i].x[1] = read();
        }
        build(root,1,n,0);
        // Start from the true maximum int rather than an arbitrary near-INT_MAX literal.
        int ret = INT_MAX;
        for(int i=1;i<=n;i++)
        {
            // qx/qy hold the input coordinates; ps itself was reordered by build().
            int mn = query(qx[i],qy[i],0),mx = query(qx[i],qy[i],1);
            ret = min(ret,mx - mn);
        }
        cout<<ret<<'\n';
    }
    View Code

    bzoj4520 K 远点对

    给 n 个点的坐标,求第 k 远的点对距离(按欧几里得距离的平方计算和输出)

    sol:

    暴力的话,用一个大小固定的小根堆维护前 k 远的距离(每个点对会从两个端点各算一次,所以代码里实际维护的是前 2k 远,并且先往堆里塞 2k 个 0)

    kd 树的话就是用堆顶做估价的门槛:如果某棵子树的包围盒到查询点的最远距离都不超过当前堆顶,说明这棵子树里的点都进不了堆,就不搜了

    #include<bits/stdc++.h>
    #define ll long long
    using namespace std;
    int n,k;
    // kd-tree over the points c[1..n], used to find the k-th farthest pair.
    namespace KD_Tree
    {
        // Splitting dimension read by cmp() while build() runs nth_element.
        int dim;
        // Min-heap of the largest squared distances seen so far. The caller
        // pre-fills it and mdf() keeps its size constant, so q.top() is the
        // smallest value still kept and serves as the pruning threshold.
        std::priority_queue<long long,std::vector<long long>,std::greater<long long> >q;
        // Offer distance x: push it, then drop the current smallest entry.
        void mdf(long long x){q.push(x);q.pop();}
        struct Point
        {
            long long P[2];          // coordinates
            int ls,rs;               // child indices into c (0 = no child)
            long long mn[2],mx[2];   // bounding box of this node's subtree
        }c[100010];
        // Points are taken by const reference throughout: a Point is several
        // long longs wide, and these helpers run many times per query or build.
        bool cmp(const Point &a,const Point &b){return a.P[dim]<b.P[dim];}
        long long sqr(long long x) {return x*x;}
        // Squared Euclidean distance between two points.
        long long calc(const Point &x,const Point &y){return sqr(x.P[0]-y.P[0])+sqr(x.P[1]-y.P[1]);}
        // Recompute node x's bounding box from its own point and its children.
        void update(int x)
        {
            for(int i=0;i<2;i++)
            {
                c[x].mn[i]=c[x].mx[i]=c[x].P[i];
                if(c[x].ls)
                {
                    c[x].mn[i]=std::min(c[x].mn[i],c[c[x].ls].mn[i]);
                    c[x].mx[i]=std::max(c[x].mx[i],c[c[x].ls].mx[i]);
                }
                if(c[x].rs)
                {
                    c[x].mn[i]=std::min(c[x].mn[i],c[c[x].rs].mn[i]);
                    c[x].mx[i]=std::max(c[x].mx[i],c[c[x].rs].mx[i]);
                }
            }
        }
        // Build a balanced kd-tree over c[l..r], splitting on dimension w and
        // alternating with depth. Returns the subtree root index, or 0 if empty.
        int build(int l,int r,int w)
        {
            if(l>r) return 0;
            if(l==r)
            {
                // A leaf must not keep child links that nth_element moved in
                // with the struct (e.g. left over from an earlier build).
                c[l].ls=c[l].rs=0;
                update(l);
                return l;
            }
            dim=w;int mid=(l+r)>>1;
            std::nth_element(c+l,c+mid,c+r+1,cmp);
            c[mid].ls=build(l,mid-1,w^1);c[mid].rs=build(mid+1,r,w^1);
            update(mid);
            return mid;
        }
        // Upper bound on the squared distance from p to any point of subtree x:
        // per axis, the squared distance to the farther edge of its box.
        long long cheque(int x,const Point &p)
        {
            long long tx=std::max(sqr(c[x].mn[0]-p.P[0]),sqr(c[x].mx[0]-p.P[0]));
            long long ty=std::max(sqr(c[x].mn[1]-p.P[1]),sqr(c[x].mx[1]-p.P[1]));
            return tx+ty;
        }
        // Offer to the heap the distances from p to every point of subtree x
        // that could still enter it. The child with the larger bound goes first.
        // A child whose bound does not exceed the heap top cannot contribute and
        // is skipped.
        void qury(int x,const Point &p)
        {
            mdf(calc(c[x],p));
            long long dl=0,dr=0;
            if(c[x].ls) dl=cheque(c[x].ls,p);
            if(c[x].rs) dr=cheque(c[x].rs,p);
            if(dl>dr)
            {
                if(dl>q.top()) qury(c[x].ls,p);
                if(dr>q.top()) qury(c[x].rs,p);
            }
            else
            {
                if(dr>q.top()) qury(c[x].rs,p);
                if(dl>q.top()) qury(c[x].ls,p);
            }
        }
    }
    using namespace KD_Tree;
    // Reads n, k and n points, then prints the k-th largest squared distance
    // among all unordered pairs of points.
    int main()
    {
        scanf("%d%d",&n,&k);
        // Each unordered pair is reached twice (once from each endpoint), so keep
        // the 2k largest values. The zeros fix the heap size and make q.top()
        // usable as a pruning threshold from the start.
        for(int i=1;i<=2*k;i++) q.push(0);
        for(int i=1;i<=n;i++)
        {
            scanf("%lld%lld",&c[i].P[0],&c[i].P[1]);
        }
        int rt=build(1,n,0);
        for(int i=1;i<=n;i++)
        {
            qury(rt,c[i]);
        }
        printf("%lld\n",q.top());
    }
    View Code

    bzoj2683 & bzoj4066 简单题

    单点加,矩形求和(强制在线,输入要和上一次的答案异或)

    sol:

    这道题要支持在 kd 树里插入一个点,类似二叉搜索树,走到儿子然后把新点加进去就可以了

    注意定期重构:下面的代码大约每新建 1000 个节点就把整棵树重建一次,防止插入让树退化

    重构的时候注意走到非法节点就把它赋成 0

    #include<bits/stdc++.h>
    #define LL long long
    using namespace std;
    // Reads one signed decimal integer from stdin using getchar().
    // Non-digit characters before the number are skipped; every '-' among them
    // flips the sign (so "--5" reads as 5). The first non-digit character after
    // the number is consumed and discarded.
    // NOTE(review): if stdin ends before any digit appears this never returns,
    // because isdigit(EOF) is false.
    inline int read()
    {
        int sign = 1;
        char c = getchar();
        while (!isdigit(c))
        {
            if (c == '-') sign = -sign;
            c = getchar();
        }
        int value = 0;
        while (isdigit(c))
        {
            value = 10 * value + c - '0';
            c = getchar();
        }
        return value * sign;
    }
    // Node pool capacity (nodes are allocated from index 1 upward; ps[0] is
    // never allocated). `inf` is not used by this program.
    const int maxn = 200010,inf = 1e9;
    // Shorthands for the child indices of node `x`. They only work where a
    // variable named `x` is in scope.
    #define L (ps[x].lc)
    #define R (ps[x].rc) 
    // n: scratch for values read in main(); dem: splitting dimension read by
    // Node::operator<; root: tree root index; ToT: number of allocated nodes.
    int n,dem,root,ToT;
    // One kd-tree node holding one distinct point and its accumulated weight.
    //   x[2]   : coordinates
    //   lc, rc : child indices into ps (0 = no child)
    //   mx, mn : bounding box of the subtree
    //   val    : total weight added at exactly this point
    //   sum    : total weight in the whole subtree
    struct Node
    {
        int x[2],lc,rc,mx[2],mn[2],val,sum;
        // Compare by the current splitting dimension `dem` (used by nth_element).
        bool operator < (const Node &b)const{return x[dem] < b.x[dem];}
        // Two nodes compare equal when they have the same coordinates.
        bool operator == (const Node& t) const { return x[0] == t.x[0] && x[1] == t.x[1]; }
        //friend int dis(Node a,Node b){return abs(a.x[0] - b.x[0]) + abs(a.x[1] - b.x[1]);}
    }ps[maxn];
    void maintain(int x)
    {
        for(int i=0;i<2;i++)
        {
            ps[x].mn[i] = ps[x].mx[i] = ps[x].x[i];
            //cout<<L<<" "<<R<<endl;
            if(L)ps[x].mx[i] = max(ps[x].mx[i],ps[L].mx[i]);
            if(L)ps[x].mn[i] = min(ps[x].mn[i],ps[L].mn[i]);
            if(R)ps[x].mx[i] = max(ps[x].mx[i],ps[R].mx[i]);
            if(R)ps[x].mn[i] = min(ps[x].mn[i],ps[R].mn[i]);
        }
        ps[x].sum = ps[x].val;
        if(L)ps[x].sum += ps[L].sum;
        if(R)ps[x].sum += ps[R].sum;
    }
    // (Re)build a balanced kd-tree over ps[l..r], alternating the splitting
    // dimension from `cur`. An empty range writes 0 into x, which also clears
    // stale child links left over from the previous tree during a rebuild.
    void build(int &x,int l,int r,int cur)
    {
        if(l <= r)
        {
            dem = cur;
            int mid = (l + r) >> 1;
            x = mid;
            nth_element(ps + l,ps + mid,ps + r + 1);
            int childDim = cur ^ 1;
            build(L,l,mid - 1,childDim);
            build(R,mid + 1,r,childDim);
            maintain(x);
            return;
        }
        x = 0;
    }
    // Insert the weighted point cur into the subtree rooted at x. cur_d is the
    // splitting dimension at this depth. An empty slot gets a freshly allocated
    // node; a node with the same coordinates absorbs cur's weight. Boxes and sums
    // are refreshed on the way back up.
    void insert(int &x,Node cur,int cur_d)
    {
        if(x == 0)
        {
            x = ++ToT;
            ps[x] = cur;
        }
        else if(ps[x] == cur)
        {
            ps[x].val += cur.val;
            ps[x].sum += cur.val;   // maintain() below recomputes sum anyway
        }
        else
        {
            int &child = (cur.x[cur_d] < ps[x].x[cur_d]) ? L : R;
            insert(child,cur,cur_d ^ 1);
        }
        maintain(x);
    }
    // 1 when node x's whole bounding box lies inside the query rectangle
    // [cur_1.x[0], cur_2.x[0]] x [cur_1.x[1], cur_2.x[1]], else 0.
    inline int all_in(Node cur_1,Node cur_2,int x)
    {
        const Node &box = ps[x];
        for(int d=0;d<2;d++)
            if(box.mn[d] < cur_1.x[d] || box.mx[d] > cur_2.x[d])return 0;
        return 1;
    }
    // 1 when node x's bounding box intersects the query rectangle
    // [cur_1.x[0], cur_2.x[0]] x [cur_1.x[1], cur_2.x[1]], else 0.
    inline int has_node(Node cur_1,Node cur_2,int x)
    {
        const Node &box = ps[x];
        for(int d=0;d<2;d++)
            if(box.mx[d] < cur_1.x[d] || box.mn[d] > cur_2.x[d])return 0;
        return 1;
    }
    // Sum of the weights of all stored points inside the rectangle
    // [cur_1.x[0], cur_2.x[0]] x [cur_1.x[1], cur_2.x[1]] within subtree x.
    // A child whose bounding box lies completely inside contributes its cached
    // sum. A child whose box misses the rectangle is skipped entirely.
    inline int query(Node cur_1,Node cur_2,int x)
    {
        if(!x)return 0;
        int ans = 0;
        // Index 0 means "no child". Test it explicitly instead of relying on the
        // never-allocated ps[0] happening to have a zero box and a zero sum.
        if(L)
        {
            if(all_in(cur_1,cur_2,L))ans += ps[L].sum;
            else if(has_node(cur_1,cur_2,L))ans += query(cur_1,cur_2,L);
        }
        if(R)
        {
            if(all_in(cur_1,cur_2,R))ans += ps[R].sum;
            else if(has_node(cur_1,cur_2,R))ans += query(cur_1,cur_2,R);
        }
        // The node's own point.
        int nx = ps[x].x[0],ny = ps[x].x[1];
        if(cur_1.x[0] <= nx && nx <= cur_2.x[0] && cur_1.x[1] <= ny && ny <= cur_2.x[1])ans += ps[x].val;
        return ans;
    }
    // Scratch nodes used by main(): the point to insert, or the two query corners.
    Node cur_1,cur_2;
    int main()
    {
        //nodes[0].val = nodes[0].sum = 0;
        n = read();int lastans = 0;
        n = read();
        while(n != 3)
        {
            if(n == 1)
            {
                cur_1.x[0] = read() ^ lastans;
                cur_1.x[1] = read() ^ lastans;
                cur_1.val = read() ^ lastans;
                insert(root,cur_1,1);if(ToT % 1000 == 0)build(root,1,ToT,1);
            }
            if(n == 2)
            {
                cur_1.x[0] = read() ^ lastans;cur_1.x[1] = read() ^ lastans;
                cur_2.x[0] = read() ^ lastans;cur_2.x[1] = read() ^ lastans;
                printf("%d
    ",lastans = query(cur_1,cur_2,root));
            }
            n = read();
        }
    }
    View Code

    bzoj4170 & 2989 数列

    给一个二维平面,每次插入一个单点,查询与给定点曼哈顿距离不超过 $d$ 的点的数量

    sol:

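    实质上是查询一个斜 45 度的正方形内有多少个点,这个东西直接做复杂度好像是错的,旋转一下坐标系即可:令 $u = x - y,\ v = x + y$,曼哈顿距离 $|\Delta x| + |\Delta y|$ 就变成切比雪夫距离 $\max(|\Delta u|, |\Delta v|)$,查询就变成了轴平行的正方形(代码里把位置 $i$ 和数值 $a_i$ 存成点 $(i - a_i,\ i + a_i)$)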

    (然而你为啥不写 CDQ 呢

    #include<bits/stdc++.h>
    #define LL long long
    using namespace std;
    // Reads one signed decimal integer from stdin using getchar().
    // Non-digit characters before the number are skipped; every '-' among them
    // flips the sign (so "--5" reads as 5). The first non-digit character after
    // the number is consumed and discarded.
    // NOTE(review): if stdin ends before any digit appears this never returns,
    // because isdigit(EOF) is false.
    inline int read()
    {
        int sign = 1;
        char c = getchar();
        while (!isdigit(c))
        {
            if (c == '-') sign = -sign;
            c = getchar();
        }
        int value = 0;
        while (isdigit(c))
        {
            value = 10 * value + c - '0';
            c = getchar();
        }
        return value * sign;
    }
    // Node pool capacity. Presumably sized for the initial n points plus one
    // node per Modify operation; confirm against the problem limits.
    const int maxn = 150010;
    // n: sequence length; q: number of operations; dem: splitting dimension
    // read by Node::operator<; ToT: number of allocated nodes; root: tree root.
    int n,q,dem,ToT,root;
    // Current value at each position of the sequence (1-based).
    int a[maxn];
    // Shorthands for the child indices of node `x`. They only work where a
    // variable named `x` is in scope.
    #define L tr[x].ls
    #define R tr[x].rs 
    // kd-tree node. main() stores position i with value v as the rotated point
    // (i - v, i + v).
    //   ls, rs : child indices into tr (0 = no child)
    //   mx, mn : bounding box of the subtree
    //   val    : multiplicity of this exact point; sum : total in the subtree
    struct Node
    {
        int x[2],ls,rs,mx[2],mn[2],val,sum;
        // Compare by the current splitting dimension `dem` (used by nth_element).
        bool operator < (const Node &b)const{return x[dem] < b.x[dem];}
        // Two nodes compare equal when they have the same coordinates.
        bool operator == (const Node &b)const{return (x[1] == b.x[1]) && (x[0] == b.x[0]);}
    }tr[maxn];
    void maintain(int x)
    {
        for(int i=0;i<2;i++)
        {
            tr[x].mx[i] = tr[x].mn[i] = tr[x].x[i];
            if(L)tr[x].mx[i] = max(tr[L].mx[i],tr[x].mx[i]);
            if(L)tr[x].mn[i] = min(tr[L].mn[i],tr[x].mn[i]);
            if(R)tr[x].mx[i] = max(tr[R].mx[i],tr[x].mx[i]);
            if(R)tr[x].mn[i] = min(tr[R].mn[i],tr[x].mn[i]); 
        }tr[x].sum = tr[x].val;
        if(L)tr[x].sum += tr[L].sum;
        if(R)tr[x].sum += tr[R].sum;
    }
    // Rebuild a balanced kd-tree over tr[l..r], alternating the splitting
    // dimension starting from cur_d. Writes the subtree root into x, or 0 for an
    // empty range, which clears stale child links during a rebuild.
    void build(int &x,int l,int r,int cur_d)
    {
        if(l > r)
        {
            x = 0;
            return;
        }
        int mid = (l + r) >> 1;
        dem = cur_d;
        nth_element(tr + l,tr + mid,tr + r + 1);
        x = mid;
        build(L,l,mid - 1,cur_d ^ 1);
        build(R,mid + 1,r,cur_d ^ 1);
        maintain(x);
    }
    // Add the weighted point cur below slot x, where cur_d is the splitting
    // dimension at this depth. Creates a new node at an empty slot and merges
    // into a node with the same coordinates. Boxes and counts are refreshed on
    // the way back up.
    void insert(int &x,Node cur,int cur_d)
    {
        if(x)
        {
            if(tr[x] == cur)
            {
                tr[x].val += cur.val;
                tr[x].sum += cur.val;   // maintain() below recomputes sum anyway
            }
            else if(cur.x[cur_d] < tr[x].x[cur_d])
                insert(L,cur,cur_d ^ 1);
            else
                insert(R,cur,cur_d ^ 1);
        }
        else
        {
            x = ++ToT;
            tr[x] = cur;
        }
        maintain(x);
    }
    // 1 when node x's whole bounding box lies inside the query rectangle
    // [cur_1.x[0], cur_2.x[0]] x [cur_1.x[1], cur_2.x[1]], else 0.
    inline int all_in(Node cur_1,Node cur_2,int x)
    {
        const Node &box = tr[x];
        for(int d=0;d<2;d++)
            if(box.mn[d] < cur_1.x[d] || box.mx[d] > cur_2.x[d])return 0;
        return 1;
    }
    // 1 when node x's bounding box intersects the query rectangle
    // [cur_1.x[0], cur_2.x[0]] x [cur_1.x[1], cur_2.x[1]], else 0.
    inline int has_node(Node cur_1,Node cur_2,int x)
    {
        const Node &box = tr[x];
        for(int d=0;d<2;d++)
            if(box.mx[d] < cur_1.x[d] || box.mn[d] > cur_2.x[d])return 0;
        return 1;
    }
    // Total multiplicity of stored points inside the rectangle
    // [cur_1.x[0], cur_2.x[0]] x [cur_1.x[1], cur_2.x[1]] within subtree x.
    // A child whose box lies completely inside contributes its cached sum.
    // A child whose box misses the rectangle is skipped.
    inline int query(Node cur_1,Node cur_2,int x)
    {
        if(!x)return 0;
        int ans = 0;
        // Index 0 means "no child". Test it explicitly instead of relying on the
        // never-allocated tr[0] happening to have a zero box and a zero sum.
        if(L)
        {
            if(all_in(cur_1,cur_2,L))ans += tr[L].sum;
            else if(has_node(cur_1,cur_2,L))ans += query(cur_1,cur_2,L);
        }
        if(R)
        {
            if(all_in(cur_1,cur_2,R))ans += tr[R].sum;
            else if(has_node(cur_1,cur_2,R))ans += query(cur_1,cur_2,R);
        }
        // The node's own point.
        int nx = tr[x].x[0],ny = tr[x].x[1];
        if(cur_1.x[0] <= nx && nx <= cur_2.x[0] && cur_1.x[1] <= ny && ny <= cur_2.x[1])ans += tr[x].val;
        return ans;
    }
    // Scratch nodes used by main(): the point to insert, or the two query corners.
    Node cur_1,cur_2;
    int main()
    {
        n = read();q = read();
        for(int i=1;i<=n;i++)a[i] = read();
        for(int i=1;i<=n;i++)
        {
            cur_1.x[0] = i - a[i],cur_1.x[1] = i + a[i],cur_1.val = 1;
            insert(root,cur_1,1);
        }build(root,1,n,1);
        char opt[50];
        while(q--)
        {
            scanf("%s",opt + 1);
            if(opt[1] == 'M')
            {
                int x = read(),y = read();
                a[x] = y;cur_1.x[0] = x - y,cur_1.x[1] = x + y,cur_1.val = 1;
                insert(root,cur_1,1);if(ToT % 2000 == 0)build(root,1,ToT,1);
            }
            else
            {
                int x = read(),y = read();
                cur_1.x[0] = x - a[x] - y;
                cur_2.x[1] = x + a[x] + y;
                cur_2.x[0] = x - a[x] + y;
                cur_1.x[1] = x + a[x] - y;
                //if(cur_1.x[0] > cur_2.x[0])swap(cur_1.x[0],cur_2.x[0]);
                //if(cur_1.x[1] > cur_2.x[1])swap(cur_1.x[1],cur_2.x[1]);
                printf("%d
    ",query(cur_1,cur_2,root));
            }
        }
    }
    View Code
  • 相关阅读:
    python ping监控
    MongoDB中一些命令
    进制转换(十进制转十六进制 十六进制转十进制)
    通过ssh建立点对点的隧道,实现两个子网通信
    linux环境下的各种后台执行
    python requests请求指定IP的域名
    不需要修改/etc/hosts,curl直接解析ip请求域名
    MongoDB数据update的坑
    windows平台使用Microsoft Visual C++ Compiler for Python 2.7编译python扩展
    rabbitmq问题之HTTP access denied: user 'guest'
  • 原文地址:https://www.cnblogs.com/Kong-Ruo/p/10171257.html
Copyright © 2011-2022 走看看