zoukankan      html  css  js  c++  java
  • KD-Tree 学习笔记

    这是一篇又长又烂的学习笔记,请做好及时退出的准备。

    KD-Tree 的复杂度大概是 (O(n^{1-frac{1}{k}}))
    (k) 是维度
    由于网上找不到靠谱的证明,咕了。
    会证明之后再补上。

    前置?

    • 考虑到平衡树不能做多维,kdt就是扩展到多维情况
    • 每次 (nth\_element) 的复杂度是 (O(n)) 的。
    • 类似替罪羊的想法,如果树不够平衡,直接 pia 重构
    • 考虑你删除元素不方便,据说只能打上标记啥的)
    • 但是你插入元素不改变树的大致结构 qwqwq

    建树显然是 (n log n)
    插入据说是 (n log^2 n)
    查询依旧是 (n log n) 的 qwq

    • 考虑建树


    假设最开始有这么多个点

    选一个中位数,把空间一分为二
    左边作为左儿子,右边作为右儿子

    再取一次

    我们定义初始是这样

    类似平衡树的结构

    建出来的树长成这样子

    然后像平衡树一样维护最小横坐标,纵坐标,最大横坐标,纵坐标,当前权值,当前坐标,sum值,就可以了。

    代码亦不难

    int build(int l , int r , int p) {
    	now = p ;
    	int mid = l + r >> 1 ;
    	nth_element(data + l , data + mid , data + r + 1) ; // data 是原数组 qwq 是 KDT
    	qwq[mid] = data[mid] ;
    	if(l < mid) qwq[mid].ls = build(l , mid - 1 , p ^ 1) ;
    	if(r > mid) qwq[mid].rs = build(mid + 1 , r , p ^ 1) ;
    	pushup(mid) ; return mid ;
    }
    
    • 考虑修改

    插入时要判是否平衡,如果不平衡就擦除一整棵子树并重构。(类似替罪羊树的想法

    void Erase(int x) {
      if (!x) return;
      pp[++m] = P[x], Erase(ls(x)), Erase(rs(x)), erase(x);
    }
    inline void insert(Point p) {
      int top = -1, x = root;
      if (!x) {
        pp[1] = p, root = build(1, 1, 1);
        return;
      }
      while (233) {
        if (max(sz[ls(x)], sz[rs(x)]) > sz[x] * alpha && top == -1) top = x;
        ++sz[x], cmin(L[x][0], p.x), cmax(R[x][0], p.x), cmin(L[x][1], p.y), cmax(R[x][1], p.y);
        int& y = ch[x][(tp[x] == 0) ? (!cmpx(p, P[x])) : (!cmpy(p, P[x]))];
        if (!y) {
          y = NewNode();
          L[y][0] = R[y][0] = p.x, L[y][1] = R[y][1] = p.y, sz[y] = 1, tp[y] = tp[x] ^ 1, fa[y] = x, P[y] = p;
          break;
        }
        x = y;
      }
      if (top == -1) return;
      m = 0;
      if (top == root) {
        Erase(top), root = build(1, m, 1);
        return;
      }
      int f = fa[top], &t = ch[f][(tp[f] == 0) ? (!cmpx(P[top], P[f])) : (!cmpy(P[top], P[f]))];
      Erase(top), t = build(1, m, tp[f]);
    }
    

    这样就可以了

    询问其实因题目而定的。。没什么具体做法

    int query(int x, int l0, int r0, int l1, int r1) {
      if (!x) return 0;
      if (l0 <= L[x][0] && R[x][0] <= r0 && l1 <= L[x][1] && R[x][1] <= r1) return sz[x];
      if (r0 < L[x][0] || R[x][0] < l0 || r1 < L[x][1] || R[x][1] < l1) return 0;
      return query(ls(x), l0, r0, l1, r1) + query(rs(x), l0, r0, l1, r1) +
             (l0 <= P[x].x && P[x].x <= r0 && l1 <= P[x].y && P[x].y <= r1);
    }
    

    比如这个就是二维数点查询个数的方法

    然后考虑一个东西,即维数问题
    (cdq)分治,你可以直接三维 (kdt) 直接狂 T 不止
    也可以排个序然后卡卡常数过去啥的)

    三维偏序

    #include <bits/stdc++.h>
    #define rep(i, x, y) for (register int i = x; i <= y; i++)
    using namespace std;
    using ll = long long;
    using pii = pair<int, int>;
    const static int _ = 1 << 20;
    char fin[_], *p1 = fin, *p2 = fin;
    inline char gc() { return (p1 == p2) && (p2 = (p1 = fin) + fread(fin, 1, _, stdin), p1 == p2) ? EOF : *p1++; }
    inline int read() {
      bool sign = 1;
      char c = 0;
      while (c < 48) ((c = gc()) == 45) && (sign = 0);
      int x = (c & 15);
      while ((c = gc()) > 47) x = (x << 1) + (x << 3) + (c & 15);
      return sign ? x : -x;
    }
    template <class T>
    void print(T x, char c = '
    ') {
      (x == 0) && (putchar(48)), (x < 0) && (putchar(45), x = -x);
      static char _st[100];
      int _stp = 0;
      while (x) _st[++_stp] = x % 10 ^ 48, x /= 10;
      while (_stp) putchar(_st[_stp--]);
      putchar(c);
    }
    template <class T>
    void cmax(T& x, T y) {
      (x < y) && (x = y);
    }
    template <class T>
    void cmin(T& x, T y) {
      (x > y) && (x = y);
    }
    
    const double alpha = 0.7;
    const int N = 1e5 + 10;
    int n, k;
    int ch[N][2], fa[N], sz[N], tp[N];
    int L[N][2], R[N][2];
    int st[N], top = 0;
    #define ls(x) ch[x][0]
    #define rs(x) ch[x][1]
    struct Point {
      int x, y, z, id;
      bool operator==(const Point& other) const { return x == other.x && y == other.y && z == other.z; }
    } p[N], P[N], pp[N];
    inline bool cmpx(const Point& x, const Point& y) {
      return (x.x == y.x) ? (x.y == y.y ? x.id < y.id : x.y < y.y) : x.x < y.x;
    }
    inline bool cmpy(const Point& x, const Point& y) {
      return (x.y == y.y) ? (x.x == y.x ? x.id < y.id : x.x < y.x) : x.y < y.y;
    }
    int root = 0, cnt = 0;
    inline void erase(int x) {
      st[++top] = x, ls(x) = rs(x) = sz[x] = L[x][0] = R[x][0] = L[x][1] = R[x][1] = 0;
      P[x] = { 0, 0, 0, 0 };
    }
    int m;
    inline int NewNode() { return top ? st[top--] : ++cnt; }
    int build(int l, int r, int lst) {
      if (l > r) return 0;
      int x = NewNode(), mn = 1e9, mx = -1e9;
      rep(i, l, r) cmin(mn, pp[i].x), cmax(mx, pp[i].x);
      L[x][0] = mn, R[x][0] = mx;
      mn = 1e9, mx = -1e9;
      rep(i, l, r) cmin(mn, pp[i].y), cmax(mx, pp[i].y);
      L[x][1] = mn, R[x][1] = mx, tp[x] = lst ^ 1;
      int mid = l + r >> 1;
      (lst) ? nth_element(pp + l, pp + mid, pp + r + 1, cmpx) : nth_element(pp + l, pp + mid, pp + r + 1, cmpy);
      P[x] = pp[mid], ls(x) = build(l, mid - 1, lst ^ 1), rs(x) = build(mid + 1, r, lst ^ 1);
      if (ls(x)) fa[ls(x)] = x;
      if (rs(x)) fa[rs(x)] = x;
      sz[x] = sz[ls(x)] + sz[rs(x)] + 1;
      return x;
    }
    void Erase(int x) {
      if (!x) return;
      pp[++m] = P[x], Erase(ls(x)), Erase(rs(x)), erase(x);
    }
    inline void insert(Point p) {
      int top = -1, x = root;
      if (!x) {
        pp[1] = p, root = build(1, 1, 1);
        return;
      }
      while (233) {
        if (max(sz[ls(x)], sz[rs(x)]) > sz[x] * alpha && top == -1) top = x;
        ++sz[x], cmin(L[x][0], p.x), cmax(R[x][0], p.x), cmin(L[x][1], p.y), cmax(R[x][1], p.y);
        int& y = ch[x][(tp[x] == 0) ? (!cmpx(p, P[x])) : (!cmpy(p, P[x]))];
        if (!y) {
          y = NewNode();
          L[y][0] = R[y][0] = p.x, L[y][1] = R[y][1] = p.y, sz[y] = 1, tp[y] = tp[x] ^ 1, fa[y] = x, P[y] = p;
          break;
        }
        x = y;
      }
      if (top == -1) return;
      m = 0;
      if (top == root) {
        Erase(top), root = build(1, m, 1);
        return;
      }
      int f = fa[top], &t = ch[f][(tp[f] == 0) ? (!cmpx(P[top], P[f])) : (!cmpy(P[top], P[f]))];
      Erase(top), t = build(1, m, tp[f]);
    }
    int query(int x, int l0, int r0, int l1, int r1) {
      if (!x) return 0;
      if (l0 <= L[x][0] && R[x][0] <= r0 && l1 <= L[x][1] && R[x][1] <= r1) return sz[x];
      if (r0 < L[x][0] || R[x][0] < l0 || r1 < L[x][1] || R[x][1] < l1) return 0;
      return query(ls(x), l0, r0, l1, r1) + query(rs(x), l0, r0, l1, r1) +
             (l0 <= P[x].x && P[x].x <= r0 && l1 <= P[x].y && P[x].y <= r1);
    }
    int ans[N], Cnt[N];
    
    signed main() {
    #ifdef _WIN64
      freopen("testdata.in", "r", stdin);
    #endif
      n = read(), k = read();
      rep(i, 1, n) { p[i].x = read(), p[i].y = read(), p[i].z = read(), p[i].id = i; }
      sort(p + 1, p + n + 1, [](const Point& x, const Point& y) { return x.z == y.z ? cmpx(x, y) : x.z < y.z; });
      for (int l = 1, r; l <= n; l = r + 1) {
        r = l;
        while (r < n && p[r + 1] == p[r]) insert(p[r++]);
        ans[r] = query(root, -1e9, p[r].x, -1e9, p[r].y), Cnt[ans[r]] += r - l + 1, insert(p[r]);
      }
      rep(i, 0, n - 1) print(Cnt[i]);
      return 0;
    }
    

    天使玩偶/SJY摆棋子

    #include <bits/stdc++.h>
    #define rep(i , x , y) for(register int i = (x) , _## i = ((y) + 1) ; i < _## i ; i ++)
    #define Rep(i , x , y) for(register int i = (x) , _## i = ((y) - 1) ; i > _## i ; i --)
    using namespace std ;
    //#define int long long
    using ll = long long ;
    using pii = pair < int , int > ;
    const static int _ = 1 << 20 ;
    char fin[_] , * p1 = fin , * p2 = fin ;
    inline char gc() {
    	return (p1 == p2) && (p2 = (p1 = fin) + fread(fin , 1 , _ , stdin) , p1 == p2) ? EOF : * p1 ++ ;
    }
    inline int read() {
    	bool sign = 1 ;
    	char c = 0 ;
    	while(c < 48) ((c = gc()) == 45) && (sign = 0) ;
    	int x = (c & 15) ;
    	while((c = gc()) > 47) x = (x << 1) + (x << 3) + (c & 15) ;
    	return sign ? x : -x ;
    }
    template < class T > void print(T x , char c = '
    ') {
    	(x == 0) && (putchar(48)) , (x < 0) && (putchar(45) , x = -x) ;
    	static char _st[100] ;
    	int _stp = 0 ;
    	while(x) _st[++ _stp] = x % 10 ^ 48 , x /= 10 ;
    	while(_stp) putchar(_st[_stp --]) ;
    	putchar(c) ;
    }
    template < class T > void cmax(T & x , T y) {
    	(x < y) && (x = y) ;
    }
    template < class T > void cmin(T & x , T y) {
    	(x > y) && (x = y) ;
    }
    
    
    struct KDT {
    	int x , y ;
    };
    bool cmp1(const KDT & x , const KDT & y) {
    	return x.x < y.x ;
    }
    bool cmp2(const KDT & x , const KDT & y) {
    	return x.y < y.y ;
    }
    int n , m , ans ;
    const int N = 3e6 + 10 ;
    KDT t[N] ;
    int ls[N] , rs[N] , p[N][2] , mx[N][2] , mn[N][2] ;
    
    void pushup(int x) {
    	cmax(mx[x][0] , mx[ls[x]][0]) , cmax(mx[x][0] , mx[rs[x]][0]) ;
    	cmax(mx[x][1] , mx[ls[x]][1]) , cmax(mx[x][1] , mx[rs[x]][1]) ;
    	cmin(mn[x][0] , mn[ls[x]][0]) , cmin(mn[x][0] , mn[rs[x]][0]) ;
    	cmin(mn[x][1] , mn[ls[x]][1]) , cmin(mn[x][1] , mn[rs[x]][1]) ;
    }
    int mxd = 0 , tot = 0 ;
    void ins(int & now , int x , int y , int d , int dep) {
    	if(! now) {
    		now = ++ tot ;
    		p[now][0] = x ;
    		p[now][1] = y ;
    		mx[now][0] = mn[now][0] = x ;
    		mx[now][1] = mn[now][1] = y ;
    		mxd = dep ;
    		return ;
    	}
    	if(! d && x < p[now][d]) ins(ls[now] , x , y , d ^ 1 , dep + 1) ;
    	else if(! d) ins(rs[now] , x , y , d ^ 1 , dep + 1) ;
    	else if(y < p[now][d]) ins(ls[now] , x , y , d ^ 1 , dep + 1) ;
    	else ins(rs[now] , x , y , d ^ 1 , dep + 1) ;
    	pushup(now) ;
    }
    void qry(int & dis , int x , int y , int now) {
    	dis = 0 ;
    	if(x > mx[now][0]) dis += x - mx[now][0] ;
    	if(x < mn[now][0]) dis += mn[now][0] - x ;
    	if(y > mx[now][1]) dis += y - mx[now][1] ;
    	if(y < mn[now][1]) dis += mn[now][1] - y ;
    }
    
    
    void query(int now , int x , int y) {
    	int disn = abs(x - p[now][0]) + abs(y - p[now][1]) ;
    	cmin(ans , disn) ;
    	int dl = 0x3f3f3f3f ;
    	int dr = dl ;
    	if(ls[now]) qry(dl , x , y , ls[now]) ;
    	if(rs[now]) qry(dr , x , y , rs[now]) ;
    	if(dl < dr) {
    		if(dl < ans) query(ls[now] , x , y) ;
    		if(dr < ans) query(rs[now] , x , y) ;
    	} else {
    		if(dr < ans) query(rs[now] , x , y) ;
    		if(dl < ans) query(ls[now] , x , y) ;
    	}
    }
    
    int build(int l , int r , int d) {
    	if(l > r) return 0 ;
    	int mid = l + r >> 1 ;
    	nth_element(t + l , t + mid , t + r + 1 , d ? cmp1 : cmp2) ;
    	int now = ++ tot ;
    	mx[now][0] = mn[now][0] = p[now][0] = t[mid].x ;
    	mx[now][1] = mn[now][1] = p[now][1] = t[mid].y ;
    	ls[now] = build(l , mid - 1 , d ^ 1) ;
    	rs[now] = build(mid + 1 , r , d ^ 1) ;
    	pushup(now) ;
    	return now ;
    }
    
    signed main() {
    #ifdef _WIN64
    	freopen("testdata.in" , "r" , stdin) ;
    #endif
    	memset(mn , 0x3f , sizeof(mn)) ;
    	memset(mx , 0xcf , sizeof(mx)) ;
    	n = read() ;
    	m = read() ;
    	rep(i , 1 , n) {
    		t[i].x = read() ;
    		t[i].y = read() ;
    	}
    	build(1 , n , 0) ;
    	int rt = 1 ;
    	rep(i , 1 , m) {
    		int opt = read() , x = read() , y = read() ;
    		if(opt == 1) {
    			ins(rt , x , y , 0 , 1) ;
    			t[++ n] = { x , y } ;
    			if(mxd > sqrt(tot)) tot = 0 , build(1 , n , 0) ;
    		} else {
    			ans = 0x3f3f3f3f ;
    			query(rt , x , y) ;
    			print(ans) ;
    		}
    	}
    	return 0 ;
    }
    

    巧克力王国

    // powered by c++11
    // by Isaunoya
    
    #include<bits/stdc++.h>
    #define rep(i , x , y) for(register int i = (x) ; i < (y) ; i ++)
    using namespace std ;
    using db = double ;
    using ll = long long ;
    using uint = unsigned int ;
    #define int long long
    using pii = pair < int , int > ;
    #define ve vector
    #define Tp template
    #define all(v) v.begin() , v.end()
    #define sz(v) ((int)v.size())
    #define pb emplace_back
    #define fir first
    #define sec second
    
    // the cmin && cmax
    Tp < class T > void cmax(T & x , const T & y) {
    	if(x < y) x = y ;
    }
    Tp < class T > void cmin(T & x , const T & y) {
    	if(x > y ) x = y ;
    }
    
    // sort , unique , reverse
    Tp < class T > void sort(ve < T > & v) {
    	sort(all(v)) ;
    }
    Tp < class T > void unique(ve < T > & v) {
    	sort(all(v)) ;
    	v.erase(unique(all(v)) , v.end()) ;
    }
    Tp < class T > void reverse(ve < T > & v) {
    	reverse(all(v)) ;
    }
    
    int n , m , now = 0 ;
    struct node {
    	int d[2] , ls , rs , val , sum ;
    	int mx[2] , mn[2] ;
    	bool operator < (const node & other) const {
    		return d[now] < other.d[now] ;
    	}
    } ;
    const int maxn = 5e4 + 10 ;
    node data[maxn] , qwq[maxn] ;
    void pushup(int o) {
    	int ls = qwq[o].ls , rs = qwq[o].rs ;
    	for(int i = 0 ; i < 2 ; i ++) {
    		qwq[o].mx[i] = qwq[o].mn[i] = qwq[o].d[i] ;
    		if(ls) {
    			cmin(qwq[o].mn[i] , qwq[ls].mn[i]) ;
    			cmax(qwq[o].mx[i] , qwq[ls].mx[i]) ;
    		}
    		if(rs) {
    			cmin(qwq[o].mn[i] , qwq[rs].mn[i]) ;
    			cmax(qwq[o].mx[i] , qwq[rs].mx[i]) ;
    		}
    	}
    	qwq[o].sum = qwq[o].val ;
    	if(ls) qwq[o].sum += qwq[ls].sum ;
    	if(rs) qwq[o].sum += qwq[rs].sum ;
    }
    int build(int l , int r , int p) {
    	now = p ;
    	int mid = l + r >> 1 ;
    	nth_element(data + l , data + mid , data + r + 1) ;
    	qwq[mid] = data[mid] ;
    	if(l < mid) qwq[mid].ls = build(l , mid - 1 , p ^ 1) ;
    	if(r > mid) qwq[mid].rs = build(mid + 1 , r , p ^ 1) ;
    	pushup(mid) ; return mid ;
    }
    int a , b , c ;
    int chk(int x , int y) { return x * a + y * b < c ; }
    int qry(int p) {
    	int cnt = 0 ;
    	cnt += chk(qwq[p].mn[0] , qwq[p].mn[1]) ;
    	cnt += chk(qwq[p].mn[0] , qwq[p].mx[1]) ;
    	cnt += chk(qwq[p].mx[0] , qwq[p].mn[1]) ;
    	cnt += chk(qwq[p].mx[0] , qwq[p].mx[1]) ;
    	if(cnt == 4) return qwq[p].sum ;
    	if(! cnt) return 0 ;
    	int res = 0 ;
    	if(chk(qwq[p].d[0] , qwq[p].d[1])) res += qwq[p].val ;
    	if(qwq[p].ls) res += qry(qwq[p].ls) ;
    	if(qwq[p].rs) res += qry(qwq[p].rs) ;
    	return res ;
    }
    
    int rt = 0 ;
    signed main() {
    	ios_base :: sync_with_stdio(false) ;
    	cin.tie(nullptr) , cout.tie(nullptr) ;
    // code begin.
    	cin >> n >> m ;
    	for(int i = 1 ; i <= n ; i ++) {
    		cin >> data[i].d[0] >> data[i].d[1] >> data[i].val ;
    	}
    	rt = build(1 , n , 0) ;
    	for(int i = 1 ; i <= m ; i ++) {
    		cin >> a >> b >> c ;
    		cout << qry(rt) << '
    ' ;
    	}
    	return 0 ;
    // code end.
    }
    
  • 相关阅读:
    Test-Driven Development
    单元测试之道(使用NUnit)
    IoC--structuremap
    web.config的configSections节点
    【转】理解POCO
    js的call(obj,arg)学习笔记
    css隐藏滚动条方法
    regexp学习
    asp后台拼接百度ueditor编辑器过程
    php关键词construct和static
  • 原文地址:https://www.cnblogs.com/Isaunoya/p/12243739.html
Copyright © 2011-2022 走看看