最近几天学习了 K-D Tree(K-D 树),有了一些理解。
首先,K-D 树本身是一种数据结构,但它的查询过程本质上是带估价函数剪枝的搜索,所以也可以把它看作一种算法。
K-D 树在形态上是一棵二叉搜索树:在当前节点所用的划分维度上,左子树中的点不大于根节点,右子树中的点不小于根节点。由于每个 K-D 树节点中存的都是一个多维的点,所以问题就变成了每一层应该按哪个维度来比较。
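为了把数据分散得更均匀,我们可以按层依次轮换划分维度(第 0 层按 x,第 1 层按 y,如此交替),并在每一层取当前维度上的中位数作为根。这时就要用到 STL 中的 nth_element(定义在 algorithm 头文件里):nth_element(&a[l],&a[mid],&a[r+1],cmp) 会在平均 O(n) 的时间内把第 mid 小的元素放到 a[mid],并保证它左边的元素都不大于它、右边的元素都不小于它。cmp 比较函数要自己写,a 是 Point 结构体的一个数组。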
// 2-D point with an attached weight `val`.
// operator< orders points along the current split dimension `now` (a global
// set by build()), breaking ties on the other dimension; nth_element uses it.
struct Point{
    ll d[2], val;
    inline ll& operator [] (int x){ return d[x]; }
    // Read-only access so coordinates can also be read through a const Point&.
    inline const ll& operator [] (int x) const { return d[x]; }
    inline bool operator < (const Point &a) const {
        return d[now] == a.d[now] ? d[now^1] < a.d[now^1] : d[now] < a.d[now];
    }
}a[MAXN];
这份代码里直接重载了<,没有写cmp函数。now是当前划分的维度,在这里划分维度的标准是挨个分,而不是按方差(见其他K-D树讲解);
如何建树?直接上代码吧。
// Build a balanced K-D tree over pt[l..r]. The split dimension alternates
// (d = 0, 1, 0, ...) level by level; `o` receives the subtree root.
void build(node *&o, int l, int r, int d = 0){
    if(l > r){
        // Empty range: clear the pointer explicitly instead of relying on the
        // caller (or the node constructor) to have initialized it.
        o = NULL;
        return;
    }
    now = d;  // Point::operator< compares along dimension `now`
    int mid = (l + r) >> 1;
    // Move the median along dimension d to pt[mid]: points before it are not
    // greater, points after it are not smaller.
    std::nth_element(&pt[l], &pt[mid], &pt[r+1]);
    o = new node(pt[mid]);
    build(o->ls, l, mid - 1, d ^ 1);
    build(o->rs, mid + 1, r, d ^ 1);
    o->Maintain();  // recompute bounding box and subtree sum from the children
}
struct node{ node *ls,*rs; Point point; ll mn[2],mx[2],sum; inline void Update(node *p){ if(!p)return; for(int i=0;i<=1;i++)mn[i]=min(mn[i],p->mn[i]); for(int i=0;i<=1;i++)mx[i]=max(mx[i],p->mx[i]); } inline void Maintain(){ sum = point.val; if(ls)Update(ls),sum+=ls->sum; if(rs)Update(rs),sum+=rs->sum; } }*root;
然后就是查找了。
对每个子树维护它对应的包围矩形,即子树内所有点 x、y 坐标的最小值和最大值(代码中的 mn[]、mx[])。
然后把查询点到这个矩形的距离(如 min_dis()、max_dis()、calc_dis())当作估价函数来进行搜索:如果估价已经不可能比当前答案更优,就不必进入这棵子树。
还是直接上代码:
#include <stdio.h> #include <cstring> #include <iostream> #include <queue> #include <algorithm> using std::max; using std::min; typedef long long ll; const ll inf = (ll)1e16; const int MAXN = 100005; int now,n,m,k; template<typename _t> inline _t read(){ _t x=0,f=1; char ch=getchar(); for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-f; for(;isdigit(ch);ch=getchar())x=x*10+(ch^48); return x*f; } struct Point{ ll a[3]; inline bool operator < (const Point & b)const{ return a[now]<b.a[now]||(a[now]==b.a[now]&&a[now^1]<b.a[now^1]); } ll &operator [](int x){return a[x];} }pt[MAXN],cmp; struct res{ ll dis,id; bool operator < (const res & a)const{ return dis == a.dis? id < a.id : dis > a.dis; } }; std::priority_queue<res>Q; inline ll sqr(ll x){return x*x;} inline ll dis(Point x,Point y){return sqr(x[0]-y[0])+sqr(x[1]-y[1]);} struct node{ node *ls,*rs; Point point; int mn[2],mx[2]; node(Point &x){ point = x; ls = rs = NULL; mn[0]=mx[0]=x[0]; mn[1]=mx[1]=x[1]; } inline void Maintain(node *x){ if(x==NULL)return; for(int i=0;i<=1;i++)mn[i]=min(mn[i],x->mn[i]); for(int i=0;i<=1;i++)mx[i]=max(mx[i],x->mx[i]); } inline ll calc_dis(){ ll Ans = 0; Ans = max(Ans,dis((Point){mn[0],mn[1]},cmp)); Ans = max(Ans,dis((Point){mn[0],mx[1]},cmp)); Ans = max(Ans,dis((Point){mx[0],mn[1]},cmp)); Ans = max(Ans,dis((Point){mx[0],mx[1]},cmp)); return Ans; } }*root; void build(node *&o,int l,int r,int d){ if(l>r)return; int mid = l+r>>1; now = d;std::nth_element(&pt[l],&pt[mid],&pt[r+1]); o = new node(pt[mid]); build(o->ls,l,mid-1,d^1); build(o->rs,mid+1,r,d^1); o->Maintain(o->ls); o->Maintain(o->rs); } inline void Query(node *rt){ if(rt==NULL)return; if(Q.size()==k&&rt->calc_dis()<Q.top().dis)return; res ans = (res){dis(rt->point,cmp),rt->point[2]}; if(Q.size()<k)Q.push(ans); else if(ans<Q.top())Q.pop(),Q.push(ans);//这个东西重载了QAQ。。 ll dis_ls = rt->ls==NULL?inf:rt->ls->calc_dis(); ll dis_rs = rt->rs==NULL?inf:rt->rs->calc_dis(); if(dis_ls>dis_rs){ Query(rt->ls); 
if(dis_rs>=Q.top().dis||Q.size()<k)Query(rt->rs); } else{ Query(rt->rs); if(dis_ls>=Q.top().dis||Q.size()<k)Query(rt->ls); } } int main(){ n=read<int>(); for(int i=1;i<=n;++i)pt[i][0]=read<int>(),pt[i][1]=read<int>(),pt[i][2]=i; build(root,1,n,0); m=read<int>(); for(int i=1;i<=m;++i){ cmp[0]=read<int>();cmp[1]=read<int>(); k=read<int>(); while(!Q.empty())Q.pop(); Query(root); printf("%d ",Q.top().id); } }这是BZOJ 2626的完整代码。
题目大意:平面上有 n 个点,每次询问给出一个点和 k,求离它第 k 远的点的编号(距离相同时,编号较小的视为更远)。
那么只需维护一个大小为 k 的堆,堆顶是目前找到的第 k 远的点;搜索完成后,堆顶就是答案。
BZOJ 1941
对每个点求离它最远和最近的点(曼哈顿距离),答案是所有点的“最远距离 − 最近距离”中的最小值。
用类似思路就可以了
用估价函数来决定先进入哪个儿子,以及另一个儿子是否还需要搜索。
#include <stdio.h> #include <cstring> #include <iostream> #include <algorithm> typedef long long ll; const ll inf = 0x3f3f3f3f3f3f3f3fll; const int MAXN = 500005; using std::min; using std::max; int n,now; ll Ans_min,Ans_max,Ans=inf; template<typename _t> inline _t read(){ _t x=0,f=1; char ch=getchar(); for(;!isdigit(ch);ch=getchar())if(ch=='-')f=-f; for(;isdigit(ch);ch=getchar())x=x*10+(ch^48); return x*f; } struct Point{ ll d[2]; ll& operator [] (const int x){return d[x];} inline bool operator != (const Point &b)const{ return d[0]!=b.d[0]||d[1]!=b.d[1]; } bool operator < (const Point x)const{ return d[now]<x.d[now]||(d[now]==x.d[now]&&d[now^1]<x.d[now^1]); } }pt[MAXN],cur,cpy[MAXN]; inline ll dis(Point a,Point b){return abs(a[0]-b[0])+abs(a[1]-b[1]);} struct node{ node *ls,*rs; Point point; ll mn[2],mx[2]; node(Point x){ ls=rs=NULL; point = x; mn[0]=mx[0]=x[0]; mn[1]=mx[1]=x[1]; } inline void Maintain(node *x){ if(x==NULL)return; for(int i=0;i<=1;++i)mn[i]=min(mn[i],x->mn[i]); for(int i=0;i<=1;++i)mx[i]=max(mx[i],x->mx[i]); } inline ll min_dis(){ ll ans = 0; ans += max(mn[0]-cur[0],0ll)+max(cur[0]-mx[0],0ll); ans += max(mn[1]-cur[1],0ll)+max(cur[1]-mx[1],0ll); return ans; } inline ll max_dis(){ ll ans = 0; ans += max(abs(cur[0]-mn[0]),abs(cur[0]-mx[0])); ans += max(abs(cur[1]-mn[1]),abs(cur[1]-mx[1])); return ans; } }*root; void build(node *&o,int l,int r,int d=0){ if(l>r)return; int mid = l+r>>1;now=d; std::nth_element(&pt[l],&pt[mid],&pt[r+1]); o = new node(pt[mid]); build(o->ls,l,mid-1,d^1); build(o->rs,mid+1,r,d^1); o->Maintain(o->ls); o->Maintain(o->rs); return ; } void Query_min(node *o){ if(o==NULL)return; if(o->point!=cur)Ans_min=min(Ans_min,dis(cur,o->point)); ll dis_l = o->ls?o->ls->min_dis():inf; ll dis_r = o->rs?o->rs->min_dis():inf; if(dis_l<dis_r){ if(o->ls)Query_min(o->ls); if(dis_r<=Ans_min&&o->rs)Query_min(o->rs); } else{ if(o->rs)Query_min(o->rs); if(dis_l<=Ans_min&&o->ls)Query_min(o->ls); } } void Query_max(node *o){ if(o==NULL)return; 
if(o->point!=cur)Ans_max=max(Ans_max,dis(cur,o->point)); ll dis_l = o->ls?o->ls->max_dis():inf; ll dis_r = o->rs?o->rs->max_dis():inf; if(dis_l>dis_r){ if(o->ls)Query_max(o->ls); if(dis_r>=Ans_max&&o->rs)Query_max(o->rs); } else{ if(o->rs)Query_max(o->rs); if(dis_l>=Ans_max&&o->ls)Query_max(o->ls); } } inline ll Query_max(Point p){ Ans_max = -inf;cur = p; Query_max(root); return Ans_max; } inline ll Query_min(Point p){ Ans_min=inf;cur=p; Query_min(root); return Ans_min; } int main(){ n=read<int>(); for(int i=1;i<=n;i++){ pt[i][0]=read<int>(); pt[i][1]=read<int>(); cpy[i]=pt[i]; } build(root,1,n); for(int i=1;i<=n;i++)Ans = min(Query_max(cpy[i])-Query_min(cpy[i]),Ans); printf("%lld ",Ans); }BZOJ 4520 和2626基本差不多。对总体维护一个大小为2*k的堆就行了,因为每个点都被算了两遍,所以要2*k。
// BZOJ 4520: print the k-th largest squared Euclidean distance among all
// unordered pairs of points. Every point queries the same heap, so each pair is
// seen twice (i->j and j->i); the heap therefore keeps 2k values.
#include <stdio.h>
#include <cctype>
#include <cstring>
#include <iostream>
#include <queue>
#include <algorithm>
using namespace std;
typedef long long ll;
const ll inf = (ll)1e16;
const int MAXN = 100005;
int now, n, m, k;

// Fast reader for a (possibly negative) integer from stdin.
template<typename _t> inline _t read(){
    _t x = 0, f = 1;
    char ch = getchar();
    for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = -f;
    for(; isdigit(ch); ch = getchar()) x = x * 10 + (ch ^ 48);
    return x * f;
}

struct Point{
    ll a[2];
    // Order by split dimension `now`, ties broken by the other one (for nth_element).
    inline bool operator < (const Point & b)const{
        return a[now] < b.a[now] || (a[now] == b.a[now] && a[now^1] < b.a[now^1]);
    }
    ll &operator [](int x){ return a[x]; }
}pt[MAXN], cmp;  // cmp: the current query point

// Min-heap of the largest distances seen so far; top = the smallest kept one.
priority_queue<ll, vector<ll> , greater<ll> > Q;

inline ll sqr(ll x){ return x * x; }
// Squared Euclidean distance.
inline ll dis(Point x, Point y){ return sqr(x[0] - y[0]) + sqr(x[1] - y[1]); }

struct node{
    node *ls, *rs;
    Point point;
    int mn[2], mx[2];  // bounding box of the subtree
    node(Point &x){
        point = x;
        ls = rs = NULL;
        mn[0] = mx[0] = x[0];
        mn[1] = mx[1] = x[1];
    }
    // Grow this node's bounding box to cover child x's box.
    inline void Maintain(node *x){
        if(x == NULL) return;
        for(int i = 0; i <= 1; i++) mn[i] = min(mn[i], x->mn[i]);
        for(int i = 0; i <= 1; i++) mx[i] = max(mx[i], x->mx[i]);
    }
    // Upper bound on the squared distance from cmp to any point of the subtree:
    // the farthest box corner, computed per dimension (same value as the max
    // over the four corners, without GNU compound literals).
    inline ll calc_dis(){
        ll ans = 0;
        for(int i = 0; i <= 1; i++) ans += max(sqr(cmp[i] - mn[i]), sqr(cmp[i] - mx[i]));
        return ans;
    }
}*root;

// Build a balanced K-D tree over pt[l..r], alternating the split dimension.
void build(node *&o, int l, int r, int d){
    if(l > r) return;
    int mid = (l + r) >> 1;
    now = d;
    nth_element(&pt[l], &pt[mid], &pt[r+1]);
    o = new node(pt[mid]);
    build(o->ls, l, mid - 1, d ^ 1);
    build(o->rs, mid + 1, r, d ^ 1);
    o->Maintain(o->ls);
    o->Maintain(o->rs);
}

// Push distances from cmp to points in rt's subtree into the shared heap Q.
inline void Query(node *rt){
    if(rt == NULL) return;
    // Prune: nothing in this box can exceed the smallest kept distance.
    if((int)Q.size() == k && rt->calc_dis() < Q.top()) return;
    ll ans = dis(rt->point, cmp);
    if((int)Q.size() < k) Q.push(ans);
    else if(ans > Q.top()) Q.pop(), Q.push(ans);
    ll dis_ls = rt->ls == NULL ? inf : rt->ls->calc_dis();
    ll dis_rs = rt->rs == NULL ? inf : rt->rs->calc_dis();
    // Farther box first; the other only if it can still reach the heap top.
    if(dis_ls > dis_rs){
        Query(rt->ls);
        if(dis_rs >= Q.top() || (int)Q.size() < k) Query(rt->rs);
    }
    else{
        Query(rt->rs);
        if(dis_ls >= Q.top() || (int)Q.size() < k) Query(rt->ls);
    }
}

int main(){
    n = read<int>(); k = read<int>(); k <<= 1;  // each pair is counted twice
    for(int i = 1; i <= n; ++i) pt[i][0] = read<int>(), pt[i][1] = read<int>();
    build(root, 1, n, 0);
    for(int i = 1; i <= n; i++) cmp = pt[i], Query(root);
    printf("%lld ", Q.top());
}

其余题目:
BZOJ 2989:支持插入的 K-D 树,借助替罪羊树的思想(子树失衡时整体重建)来保持平衡。
BZOJ 4066:思路与 BZOJ 2989 基本相同。
BZOJ 2850:巧克力王国。