zoukankan      html  css  js  c++  java
  • python kd树 搜索 代码

      kd树就是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构,可以运用在k近邻法中,实现快速k近邻搜索。构造kd树相当于不断地用垂直于坐标轴的超平面将k维空间切分,依次选择坐标轴对空间进行切分,选择训练实例点在选定坐标轴上的中位数为切分点。具体kd树的原理可以参考kd树的原理。

      代码是参考《统计学习方法》k近邻 kd树的python实现得到

      首先创建一个类,用于表示树的节点,包括:该节点的值,用于划分左右子树的切分轴,左子树,右子树

    class decisionnode:
        def __init__(self,value=None,col=None,rb=None,lb=None):
            self.value=value
            self.col=col
            self.rb=rb
            self.lb=lb

      切分点为坐标轴上的中值,下面代码求得一个序列的中值

    def median(x):
        n=len(x)
        x=list(x)
        x_order=sorted(x)
        return x_order[n//2],x.index(x_order[n//2])

    然后就可以构造一颗kd树,左子树小于切分点,右子树大于切分点,data是输入的数据

    def buildtree(x,j=0):
        rb=[]
        lb=[]
        m,n=x.shape
        if m==0: return None
        edge,row=median(x[:,j].copy())
        for i in range(m):
            if x[i][j]>edge: 
                rb.append(i)
            if x[i][j]<edge:
                lb.append(i)
        rb_x=x[rb,:]
        lb_x=x[lb,:]
        rightBranch=buildtree(rb_x,(j+1)%n)
        leftBranch=buildtree(lb_x,(j+1)%n)
        return decisionnode(x[row,:],j,rightBranch,leftBranch)

      接下来是树的搜索过程,可以用下图表示树的搜索过程,具体过程可以参考kd树的原理。

      

      代码如下:

    #搜索树:nearestPoint,nearestValue均为全局变量
    def traveltree(node,point):
        global nearestPoint,nearestValue
        if node==None: return 
        print(node.value)
        print('---')
        col=node.col
        if point[col]>node.value[col]:
            traveltree(node.rb,point)
        if point[col]<node.value[col]:
            traveltree(node.lb,point)
        dis=dist(node.value,point)
        print(dis)
        if dis<nearestValue:
            nearestPoint=node
            nearestValue=dis
            #print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
        if node.rb!=None or node.lb!=None:
            if abs(point[node.col] - node.value[node.col]) < nearestValue:
                if point[node.col]<node.value[node.col]:
                    traveltree(node.rb,point)
                if point[node.col]>node.value[node.col]:
                    traveltree(node.lb,point)
            
    def searchtree(tree,aim):
        global nearestPoint,nearestValue
        #nearestPoint=None
        nearestValue=float('inf')
        traveltree(tree,aim)
        return nearestPoint
            
        
    def dist(x1, x2): #欧式距离的计算  
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

     完整代码在此处取

     1 import numpy as np
     2 from numpy import array
     3 class decisionnode:
     4     def __init__(self,value=None,col=None,rb=None,lb=None):
     5         self.value=value
     6         self.col=col
     7         self.rb=rb
     8         self.lb=lb
     9         
    10 #读取数据并将数据转换为矩阵形式        
    11 def readdata(filename):    
    12     data=open(filename).readlines()
    13     x=[]
    14     for line in data:
    15         line=line.strip().split('	')
    16         x_i=[]
    17         for num in line:
    18             num=float(num)
    19             x_i.append(num)
    20         x.append(x_i)
    21     x=array(x)
    22     return x
    23 
    24 #求序列的中值    
    25 def median(x):
    26     n=len(x)
    27     x=list(x)
    28     x_order=sorted(x)
    29     return x_order[n//2],x.index(x_order[n//2])
    30 
    31 #以j列的中值划分数据,左小右大,j=节点深度%列数    
    32 def buildtree(x,j=0):
    33     rb=[]
    34     lb=[]
    35     m,n=x.shape
    36     if m==0: return None
    37     edge,row=median(x[:,j].copy())
    38     for i in range(m):
    39         if x[i][j]>edge: 
    40             rb.append(i)
    41         if x[i][j]<edge:
    42             lb.append(i)
    43     rb_x=x[rb,:]
    44     lb_x=x[lb,:]
    45     rightBranch=buildtree(rb_x,(j+1)%n)
    46     leftBranch=buildtree(lb_x,(j+1)%n)
    47     return decisionnode(x[row,:],j,rightBranch,leftBranch)
    48 
    49 #搜索树:nearestPoint,nearestValue均为全局变量
    50 def traveltree(node,point):
    51     global nearestPoint,nearestValue
    52     if node==None: return 
    53     print(node.value)
    54     print('---')
    55     col=node.col
    56     if point[col]>node.value[col]:
    57         traveltree(node.rb,point)
    58     if point[col]<node.value[col]:
    59         traveltree(node.lb,point)
    60     dis=dist(node.value,point)
    61     print(dis)
    62     if dis<nearestValue:
    63         nearestPoint=node
    64         nearestValue=dis
    65         #print('nearestPoint,nearestValue' % (nearestPoint,nearestValue))
    66     if node.rb!=None or node.lb!=None:
    67         if abs(point[node.col] - node.value[node.col]) < nearestValue:
    68             if point[node.col]<node.value[node.col]:
    69                 traveltree(node.rb,point)
    70             if point[node.col]>node.value[node.col]:
    71                 traveltree(node.lb,point)
    72         
    73 def searchtree(tree,aim):
    74     global nearestPoint,nearestValue
    75     #nearestPoint=None
    76     nearestValue=float('inf')
    77     traveltree(tree,aim)
    78     return nearestPoint
    79         
    80     
    81 def dist(x1, x2): #欧式距离的计算  
    82     return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5  
    kdtree

     

  • 相关阅读:
    [中英对照]INTEL与AT&T汇编语法对比
    用gdb理解C宏(#和##)
    Unix/Linux文件类型及访问权限
    apt-get
    查看ip地址信息和配置临时ip
    修改文件所有者 chown
    修改文件权限 chmod
    tar命令
    PHP magic_quotes_gpc
    chmod命令详细用法
  • 原文地址:https://www.cnblogs.com/bambipai/p/8436703.html
Copyright © 2011-2022 走看看