zoukankan      html  css  js  c++  java
  • KD-tree学习笔记(超全!)

    因为之前找不到全的博客,唯一的一篇码风比较毒瘤。。。

    所以我就来写了

    K-D树

    大概是高维二叉树吧

    每次按一个维度对超空间内的点进行二分划分

    树上存左右节点和这个节点所代表的的点

    更新信息

    我们保存几个信息:

    1. size 在重构的时候有用
    2. min[2],max[2],,就是子树中每个维度的值的最值,即处理出当前节点所代表的空间
    3. 题目中的其他信息,比如区间总权值
    void push_up(int now){
    	int l=ls[now],r=rs[now];t[now].sz=t[l].sz+t[r].sz+1;t[now].sum=t[l].sum+t[r].sum+t[now].c.cnt;
    	for(register int i=0;i<=1;i++){
    		t[now].mi[i]=t[now].mx[i]=t[now].c.x[i];
    		if(l) t[now].mi[i]=min(t[now].mi[i],t[l].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[l].mx[i]);
    		if(r) t[now].mi[i]=min(t[now].mi[i],t[r].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[r].mx[i]);
    	}
    }
    

    建树

    递归进行,每次选择一个维度进行划分,每次(O(N))共大约(log n)

    注意应用(nth)_(element)函数,要定义point之间的比较符号

    int operator < (point a,point b){
    	return a.x[D]<b.x[D];
    }
    inline int build(int l,int r,int d){
    	if(l>r) return 0;
    	int now=newnode(),mid=(l+r)>>1;
    	D=d,nth_element(p+l,p+mid,p+r+1);
    	t[now].c=p[mid],ls[now]=build(l,mid-1,d^1),rs[now]=build(mid+1,r,d^1);
    	push_up(now); return now;
    }
    

    插入

    另一种形式的建树。。。

    就是找到对应的区间加点就行了,跟平衡树差不多,注意push_up

    inline void insert(int &now,point p,int d){
    	if(!now){
    		now=newnode();ls[now]=rs[now]=0;t[now].c=p;push_up(now);return;
    	}
    	if(p.x[d]<=t[now].c.x[d]) insert(ls[now],p,d^1);
        else insert(rs[now],p,d^1);
    	push_up(now);check(now,d);
    }
    

    查询

    每次到达一个节点,首先判断这个节点是不是被查询区间完全包含

    如果是,统计答案并退出

    然后分三部分查询:本节点,左右儿子区间

    本节点直接判断,左右儿子区间判断是否和查询区间有交集,有就递归

    有论文证明了矩形操作里面复杂度是(n ^{frac{k-1}{k}})的,k是维度数

    这个复杂度很大,一般用在k=2的时候

    对于高维我们可以排序去掉一维或者CDQ分治

    struct sqr{
    	int x1,x2,y1,y2;
    }q;
    int chkin(int now,sqr tp){
    	return (!(t[now].mx[0]<tp.x1||t[now].mi[0]>tp.x2||t[now].mx[1]<tp.y1||t[now].mi[1]>tp.y2));
    }
    int totalin(int now,sqr tp){
    	return (t[now].mx[0]<=tp.x2&&t[now].mi[0]>=tp.x1&&t[now].mx[1]<=tp.y2&&t[now].mi[1]>=tp.y1);
    }
    int ptin(point a,sqr b){
    	return (b.x2>=a.x[0]&&b.x1<=a.x[0]&&b.y1<=a.x[1]&&b.y2>=a.x[1]);
    }
    inline void query(int now,sqr tp){
    	if(!now) return 0;
    	int re=0;
    	if(totalin(now,tp)){
    		ans+=t[now].sum;return;
    	}
    	if(ptin(t[now].c,tp)) ans+=t[now].c.cnt;
    	int l=ls[now],r=rs[now];
    	if(chkin(l,tp)) query(l,tp);
        if(chkin(r,tp)) query(r,tp);
    	return re;
    }
    

    k远/近询问

    构造一个小/大根堆,先push几个0/inf

    然后query树更新就行了,用估价函数来判断区间包含和剪枝(决定搜索顺序

    复杂度不稳定,没有保证,需要卡常

    下面是K远点(曼哈顿距离,我转化成平方避免小数)查询的代码

    int dissqr(point tp,int a){
    	int di=0; 
    	for(int i=0;i<=1;i++){
    		int nd=0;
    		if(tp.x[i]<t[a].mi[i]) nd=t[a].mx[i]-tp.x[i]; 
             else if(tp.x[i]>t[a].mx[i]) nd=tp.x[i]-t[a].mi[i];
    		else nd=max(tp.x[i]-t[a].mi[i],t[a].mx[i]-tp.x[i]);
    		di+=nd*nd; 
    	}
    	return di;
    }
    void query(int now,point tp){
    	int di=get_dis(t[now].c,tp);if(di>q.top()) q.pop(),q.push(di);
    	int l=ls[now],r=rs[now],dl,dr;
    	dl=l?dissqr(tp,l):-inf,dr=r?dissqr(tp,r):-inf;
    	if(dl>dr){
    		if(dl>q.top()) query(l,tp);
    		if(dr>q.top()) query(r,tp);
    	}else{
    		if(dr>q.top()) query(r,tp);
    		if(dl>q.top()) query(l,tp);
    	}
    }
    

    重构

    每次insert的时候check一下就可以啦

    参考替罪羊树,设一个重构参数

    还有就是注意回收节点内存,开个栈

    #define alpha 0.75
    int rub[N],top;
    inline int newnode(){
    	if(top) return rub[top--];
    	else return ++tot;
    }
    inline void clear(int now,int pos){
    	if(ls[now]) clear(ls[now],pos);
    	p[pos+t[ls[now]].sz+1]=t[now].c,rub[++top]=now;
    	if(rs[now]) clear(rs[now],pos+t[ls[now]].sz+1);
    }
    inline void check(int &now,int d){
    	if(alpha*(double)(t[now].sz)<(double)(t[ls[now]].sz)||alpha*(double)(t[now].sz)<(double)(t[rs[now]].sz)){
    		clear(now,0);now=build(1,t[now].sz,d);
    	}
    }
    inline void insert(int &now,point p,int d){
    	...
    	check(now,d);
    }
    

    完整模板

    K远点对

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cstring>
    #include<queue>
    #define inf 192608170000000ll 
    #define ll long long
    using namespace std;
    long long read(){
    	long long x=0,pos=1;char ch=getchar();
    	for(;!isdigit(ch);ch=getchar()) if(ch=='-') pos=0;
    	for(;isdigit(ch);ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    	return pos?x:-x;
    } 
    const long long N = 200001;
    long long n,k;
    struct point{
    	long long x[2];
    }p[N];
    struct cmp{
    	long long operator()(long long a,long long b){
    		return a>b;
    	}
    };
    priority_queue<long long,vector<long long>,cmp>q;
    struct node{
    	long long mi[2],mx[2],sz;point c;
    }t[N];
    long long rt,D,rs[N],ls[N];
    long long operator < (point a,point b){
    	return a.x[D]<b.x[D];
    }
    void push_up(long long now){
    	long long l=ls[now],r=rs[now];
    	t[now].sz=t[l].sz+t[r].sz+1;
    	for(register long long i=0;i<=1;i++){
    		t[now].mi[i]=t[now].mx[i]=t[now].c.x[i];
    		if(l) t[now].mi[i]=min(t[now].mi[i],t[l].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[l].mx[i]);
    		if(r) t[now].mi[i]=min(t[now].mi[i],t[r].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[r].mx[i]);
    	}
    }
    long long tot=0; 
    void build(long long &now,long long l,long long r,long long d){
    	if(l>r) return;
    	now=++tot;
    	long long mid=(l+r)>>1;
    	D=d;nth_element(p+l,p+mid,p+r+1);
    	t[now].c=p[mid];
    	build(ls[now],l,mid-1,d^1);
    	build(rs[now],mid+1,r,d^1);
    	push_up(now);
    }
    inline long long abs(long long a){
    	return a>0?a:-a;
    }
    long long get_dis(point a,point b){
    	return (a.x[0]-b.x[0])*(a.x[0]-b.x[0])+(a.x[1]-b.x[1])*(a.x[1]-b.x[1]);
    }
    long long dissqr(point tp,long long a){
    	long long di=0; 
    	for(long long i=0;i<=1;i++){
    		long long nd=0;
    		if(tp.x[i]<t[a].mi[i]){
    			nd=t[a].mx[i]-tp.x[i]; 
    		}else if(tp.x[i]>t[a].mx[i]){
    			nd=tp.x[i]-t[a].mi[i];
    		}else nd=max(tp.x[i]-t[a].mi[i],t[a].mx[i]-tp.x[i]);
    		di+=nd*nd; 
    	}
    	return di;
    }
    void query(long long now,point tp){
    	long long di=get_dis(t[now].c,tp);if(di>q.top()) q.pop(),q.push(di);
    	long long l=ls[now],r=rs[now],dl,dr;
    	dl=l?dissqr(tp,l):-inf,dr=r?dissqr(tp,r):-inf;
    	if(dl>dr){
    		if(dl>q.top()) query(l,tp);
    		if(dr>q.top()) query(r,tp);
    	}else{
    		if(dr>q.top()) query(r,tp);
    		if(dl>q.top()) query(l,tp);
    	}
    }
    int main(){
    	n=read(),k=read();
    	for(register long long i=1;i<=n;i++){
    		p[i].x[0]=read();
    		p[i].x[1]=read();
    	}
    	build(rt,1,n,0);
    	for(register long long i=1;i<=2*k;i++){
    		q.push(0);
    	}
    	for(long long i=1;i<=n;i++){
    		query(rt,p[i]);
    	}
    	/*putchar(10);
    	for(long long i=1;i<=n;i++){
    		printf("%d %d
    ",p[i].x[0],p[i].x[1]);
    	}
    	for(long long i=1;i<=n;i++){
    		for(long long j=1;j<=n;j++){
    			printf("%d ",get_dis(p[i],p[j]));
    		}
    		putchar(10);
    	}*/
    	printf("%lld",q.top());
    	return 0;
    }
    

    MOKIA(三维数点)

    我偏不写CDQ

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cstring>
    #define inf 1926081700
    #define alpha 0.75
    #define ll long long 
    using namespace std;
    int read(){
    	int x=0,pos=1;char ch=getchar();
    	for(;!isdigit(ch);ch=getchar()) if(ch=='-') pos=0;
    	for(;isdigit(ch);ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    	return pos?x:-x;
    } 
    const int N = 400001;
    int n,k,ans,lnk[N],lst[N],rub[N];
    struct sqr{
    	int x1,x2,y1,y2;
    }q;
    struct point{
    	int x[2],cnt;
    }p[N],pn;
    struct node{
    	int mi[2],mx[2],sz,sum;point c;
    }t[N];
    int rt,D,rs[N],ls[N],top,tot;
    int operator < (point a,point b){
    	return a.x[D]<b.x[D];
    }
    void push_up(int now){
    	int l=ls[now],r=rs[now];t[now].sz=t[l].sz+t[r].sz+1;t[now].sum=t[l].sum+t[r].sum+t[now].c.cnt;
    	for(register int i=0;i<=1;i++){
    		t[now].mi[i]=t[now].mx[i]=t[now].c.x[i];
    		if(l) t[now].mi[i]=min(t[now].mi[i],t[l].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[l].mx[i]);
    		if(r) t[now].mi[i]=min(t[now].mi[i],t[r].mi[i]),t[now].mx[i]=max(t[now].mx[i],t[r].mx[i]);
    	}
    }
    inline int newnode(){
    	if(top) return rub[top--];
    	else return ++tot;
    }
    inline int build(int l,int r,int d){
    	if(l>r) return 0;
    	int now=newnode(),mid=(l+r)>>1;
    	D=d,nth_element(p+l,p+mid,p+r+1);
    	t[now].c=p[mid],ls[now]=build(l,mid-1,d^1),rs[now]=build(mid+1,r,d^1);
    	push_up(now); return now;
    }
    inline void clear(int now,int pos){
    	if(ls[now]) clear(ls[now],pos);
    	p[pos+t[ls[now]].sz+1]=t[now].c,rub[++top]=now;
    	if(rs[now]) clear(rs[now],pos+t[ls[now]].sz+1);
    }
    inline void check(int &now,int d){
    	if(alpha*(double)(t[now].sz)<(double)(t[ls[now]].sz)||alpha*(double)(t[now].sz)<(double)(t[rs[now]].sz)){
    		clear(now,0);now=build(1,t[now].sz,d);
    	}
    }
    inline void insert(int &now,point p,int d){
    	if(!now){
    		now=newnode();ls[now]=rs[now]=0;t[now].c=p;push_up(now);return;
    	}
    	if(p.x[d]<=t[now].c.x[d]){
    		insert(ls[now],p,d^1);
    	}else{
    		insert(rs[now],p,d^1);
    	}
    	push_up(now);check(now,d);
    }
    int chkin(int now,sqr tp){
    	return (!(t[now].mx[0]<tp.x1||t[now].mi[0]>tp.x2||t[now].mx[1]<tp.y1||t[now].mi[1]>tp.y2));
    }
    int totalin(int now,sqr tp){
    	return (t[now].mx[0]<=tp.x2&&t[now].mi[0]>=tp.x1&&t[now].mx[1]<=tp.y2&&t[now].mi[1]>=tp.y1);
    }
    int ptin(point a,sqr b){
    	return (b.x2>=a.x[0]&&b.x1<=a.x[0]&&b.y1<=a.x[1]&&b.y2>=a.x[1]);
    }
    inline int query(int now,sqr tp){
    	if(!now) return 0;
    	int re=0;
    	if(totalin(now,tp)){
    		return t[now].sum;
    	}else if(!chkin(now,tp)) return 0;
    	if(ptin(t[now].c,tp)) re+=t[now].c.cnt;
    	int l=ls[now],r=rs[now];
    	re+= query(l,tp);
    	re+= query(r,tp);
    	return re;
    }
    int main(){
    	int qqq=read(),ppp=read(),opt;//前两个数并没有什么用
    	while(opt=read())
    		if(opt==1){
    			pn.x[0]=(read()),pn.x[1]=(read()),pn.cnt=(read());
    			insert(rt,pn,0);
    		}else if(opt==2){
    			q.x1=(read()),q.y1=(read()),q.x2=(read()),q.y2=(read());
    			ans=query(rt,q);printf("%d
    ",ans);
    		}else return 0;
    	return 0;
    }
    

    K-D 树优化建边

    NOI 2019考到了所以写一写

    竟然1A了。。。(可能是之前一些KDT的题调了好久所以比较熟悉

    思路跟线段树的差不多,这题不过空间开不下,所以考虑不保存边

    考虑dijkstra算法中每个点只能作为中间节点松弛连的节点一次(vis)

    于是建边的复杂度就跟每次直接K-D树上查询复杂度一样啦

    具体来说,

    1. 如果当前点是原来的点,直接上树查询并松弛
    2. 如果是树上的点,它不可能再向树上区间连边,只连向它的左右儿子和对应的原点

    码量也不是很大(还没有splay大),注意细节

    #include<iostream>
    #include<cstdio>
    #include<algorithm>
    #include<cstring>
    #include<queue>
    #define inf 1926081700;
    using namespace std;
    int read(){
    	int x=0,pos=1;char ch=getchar();
    	for(;!isdigit(ch);ch=getchar()) if(ch=='-') pos=0;
    	for(;isdigit(ch);ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
    	return pos?x:-x;
    } 
    const int N = 75001;
    struct point{
    	int x[2],ori;
    }p[N<<1];
    struct node{
    	int mx[2],mi[2],sz,ord;
    	point c;
    }t[N<<1];
    int ls[N<<1],rs[N<<1];
    int n,m,w,h,tot,D;
    int operator < (point a,point b){
    	return a.x[D]<b.x[D];
    }
    int operator > (point a,point b){
    	return a.x[D]>b.x[D];
    }
    inline void push_up(int now){
    	int l=ls[now],r=rs[now];
    	t[now].sz=t[l].sz+t[r].sz+1;
    	t[now].mi[0]=t[now].mx[0]=t[now].c.x[0];t[now].mi[1]=t[now].mx[1]=t[now].c.x[1];
    	if(l) t[now].mi[0]=min(t[now].mi[0],t[l].mi[0]),t[now].mi[1]=min(t[now].mi[1],t[l].mi[1]),t[now].mx[0]=max(t[now].mx[0],t[l].mx[0]),t[now].mx[1]=max(t[now].mx[1],t[l].mx[1]);
    	if(r) t[now].mi[0]=min(t[now].mi[0],t[r].mi[0]),t[now].mi[1]=min(t[now].mi[1],t[r].mi[1]),t[now].mx[0]=max(t[now].mx[0],t[r].mx[0]),t[now].mx[1]=max(t[now].mx[1],t[r].mx[1]);
    }
    inline void build(int &now,int l,int r,int d){
    	if(l>r) return; 
    	now=++tot;int mid=(l+r)>>1;
    	D=d;nth_element(p+l,p+mid,p+r+1);t[now].c=p[mid];t[now].ord=p[mid].ori;
    	build(ls[now],l,mid-1,d^1);build(rs[now],mid+1,r,d^1);
    	push_up(now);
    } 
    struct sqr{
    	int x1,x2,y1,y2,w;
    }qu[N<<1];
    struct graph{
    	int v,nex;
    }edge[N<<1];
    int tope=0,head[N],dis[N<<1],vis[N<<1],rt;
    void add(int u,int v){
    	edge[++tope].v=v;
    	edge[tope].nex=head[u];
    	head[u]=tope;
    }
    struct type{
    	int pt,w;
    };
    struct cmp{
    	int operator()(type a,type b){
    		return a.w>b.w;
    	}
    };
    priority_queue<type,vector<type>,cmp> q;
    inline type mk(int a,int b){
    	type nw;nw.pt=a,nw.w=b;return nw;
    }
    inline void relax(int u,int v,int w){
    	if(dis[v]>dis[u]+w){
    		dis[v]=dis[u]+w;
    		if(!vis[v]){
    			q.push(mk(v,dis[v]));
    		}
    	}
    }
    inline int totalin(int now,sqr tp){
    	return (t[now].mi[0]>=tp.x1&&t[now].mx[0]<=tp.x2&&t[now].mi[1]>=tp.y1&&t[now].mx[1]<=tp.y2); 
    }
    inline int totalout(int now,sqr tp){
    	return (t[now].mx[0]<tp.x1||t[now].mi[0]>tp.x2||t[now].mx[1]<tp.y1||t[now].mi[1]>tp.y2); 
    }
    inline int ptin(point now,sqr tp){
    	return (now.x[0]>=tp.x1&&now.x[0]<=tp.x2&&now.x[1]>=tp.y1&&now.x[1]<=tp.y2); 
    }
    inline void query(int now,sqr tp,int u){
    	if(totalin(now,tp)){
    		relax(u,now,tp.w);
    		return;
    	}
    	if(ptin(t[now].c,tp)) relax(u,t[now].ord,tp.w);
    	int l=ls[now],r=rs[now];
    	if(!totalout(l,tp)) query(l,tp,u);
    	if(!totalout(r,tp)) query(r,tp,u); 
    }
    inline void dijkstra(){
    	q.push(mk(1,0));dis[1]=0;
    	for(int i=2;i<=tot;i++){
    		dis[i]=inf;
    	}
    	while(!q.empty()){
    		int now=q.top().pt;q.pop();
    		if(vis[now]) continue;else vis[now]=1;
    		if(now<=n){
    			for(int i=head[now];i;i=edge[i].nex){
    				int v=edge[i].v;
    				query(rt,qu[v],now);
    			}
    		}else{
    			relax(now,ls[now],0);
    			relax(now,rs[now],0);
    			relax(now,t[now].ord,0); 
    		}
    	}
    	for(int i=2;i<=n;i++){
    		printf("%d
    ",dis[i]);
    	}
    }
    int main(){
    	n=read(),m=read(),w=read(),h=read();
    	for(int i=1;i<=n;i++){
    		p[i].x[0]=read(),p[i].x[1]=read(),p[i].ori=i;
    	}
    	tot=n;
    	build(rt,1,n,1);
    	for(int i=1;i<=m;i++){
    		int u=read();
    		qu[i].w=read(),qu[i].x1=read(),qu[i].x2=read(),qu[i].y1=read(),qu[i].y2=read();
    		add(u,i);
    	}
    	dijkstra();
    	return 0;
    }
    

    后记

    感觉数据结构也学的差不多了吧。。。

    之后可能会写的数据结构博客:

    top-tree/李超线段树/势能线段树/毒瘤分块题

  • 相关阅读:
    webStorm常用快捷键
    npm 常用指令
    webpack配置详解
    Tornado-StaticFileHandler参考
    python-希尔排序
    python的__init__几种方法总结
    gitlab和github一起使用
    Git的一些知识
    关于Django的理解
    python-快速排序
  • 原文地址:https://www.cnblogs.com/lcyfrog/p/11624855.html
Copyright © 2011-2022 走看看