zoukankan      html  css  js  c++  java
  • kd树 求k近邻 python 代码

       之前两篇随笔介绍了kd树的原理,并用python实现了kd树的构建和搜索,具体可以参考

      kd树的原理

      python kd树 搜索 代码

      kd树常与knn算法联系在一起,knn算法通常要搜索k近邻,而不仅仅是最近邻,下面的代码将利用kd树搜索目标点的k个近邻。

      首先还是创建一个类,用于保存结点的值,左右子树,以及用于划分左右子树的切分轴

    class decisionnode:
        def __init__(self,value=None,col=None,rb=None,lb=None):
            self.value=value
            self.col=col
            self.rb=rb
            self.lb=lb

      切分点为坐标轴上的中值,下面代码求得一个序列的中值

    def median(x):
        n=len(x)
        x=list(x)
        x_order=sorted(x)
        return x_order[n//2],x.index(x_order[n//2])

      然后按照左子树大于切分点,右子树小于切分点的规则构造kd树,其中data是输入的数据

    #以j列的中值划分数据,左小右大,j=节点深度%列数    
    def buildtree(x,j=0):
        rb=[]
        lb=[]
        m,n=x.shape
        if m==0: return None
        edge,row=median(x[:,j].copy())
        for i in range(m):
            if x[i][j]>edge: 
                rb.append(i)
            if x[i][j]<edge:
                lb.append(i)
        rb_x=x[rb,:]
        lb_x=x[lb,:]
        rightBranch=buildtree(rb_x,(j+1)%n)
        leftBranch=buildtree(lb_x,(j+1)%n)
        return decisionnode(x[row,:],j,rightBranch,leftBranch)

       接下来就是搜索树得到k近邻的过程,与搜索最近邻的过程大致相同,需要创建一个字典knears,用于存储k近邻的点以及与目标点的距离(欧氏距离)

      搜索的过程为:

      (1)第一步还是遍历树,找到目标点所属区域对应的叶节点

      (2)从叶结点依次向上回退,按照寻找最近邻点的方法回退到父节点,并判断其另一个子节点对区域内是否可能存在k近邻点,具体的,在每个结点上进行以下操作:

        (a)如果字典中的成员个数不足k个,将该结点加入字典

        (b)如果字典中的成员不少于k个,判断该结点与目标结点之间的距离是否不大于字典中各结点所对应距离的的最大值,如果不大于,便将其加入到字典中

        (c)对于父节点来说,如果目标点与其切分轴之间的距离不大于字典中各结点所对应距离的的最大值,便需要访问该父节点的另一个子节点

      (3)每当字典中增加新成员,就按距离值对字典进行降序排序,将得到的列表赋值给poinelist,pointlist[0][1]便是字典中各结点所对应距离的最大值

      (4)当回退到根节点并完成对其操作时,pointlist中后k个结点就是目标点的k近邻

      代码如下:

    #搜索树:输出目标点的近邻点
    def traveltree(node,aim):
        global pointlist  #存储排序后的k近邻点和对应距离
        if node==None: return 
        col=node.col
        if aim[col]>node.value[col]:
            traveltree(node.rb,aim)
        if aim[col]<node.value[col]:
            traveltree(node.lb,aim)
        dis=dist(node.value,aim)
        if len(knears)<k:
            knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作为字典的键
            pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
        elif dis<=pointlist[0][1]:
            knears.setdefault(tuple(node.value.tolist()),dis)
            pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
        if node.rb!=None or node.lb!=None:
            if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]:
                if aim[node.col]<node.value[node.col]:
                    traveltree(node.rb,aim)
                if aim[node.col]>node.value[node.col]:
                    traveltree(node.lb,aim)
        return pointlist

      完整代码在此处取

     1 import numpy as np
     2 from numpy import array
     3 class decisionnode:
     4     def __init__(self,value=None,col=None,rb=None,lb=None):
     5         self.value=value
     6         self.col=col
     7         self.rb=rb
     8         self.lb=lb
     9         
    10 #读取数据并将数据转换为矩阵形式        
    11 def readdata(filename):    
    12     data=open(filename).readlines()
    13     x=[]
    14     for line in data:
    15         line=line.strip().split('	')
    16         x_i=[]
    17         for num in line:
    18             num=float(num)
    19             x_i.append(num)
    20         x.append(x_i)
    21     x=array(x)
    22     return x
    23 
    24 #求序列的中值    
    25 def median(x):
    26     n=len(x)
    27     x=list(x)
    28     x_order=sorted(x)
    29     return x_order[n//2],x.index(x_order[n//2])
    30 
    31 #以j列的中值划分数据,左小右大,j=节点深度%列数    
    32 def buildtree(x,j=0):
    33     rb=[]
    34     lb=[]
    35     m,n=x.shape
    36     if m==0: return None
    37     edge,row=median(x[:,j].copy())
    38     for i in range(m):
    39         if x[i][j]>edge: 
    40             rb.append(i)
    41         if x[i][j]<edge:
    42             lb.append(i)
    43     rb_x=x[rb,:]
    44     lb_x=x[lb,:]
    45     rightBranch=buildtree(rb_x,(j+1)%n)
    46     leftBranch=buildtree(lb_x,(j+1)%n)
    47     return decisionnode(x[row,:],j,rightBranch,leftBranch)
    48 
    49 #搜索树:输出目标点的近邻点
    50 def traveltree(node,aim):
    51     global pointlist  #存储排序后的k近邻点和对应距离
    52     if node==None: return 
    53     col=node.col
    54     if aim[col]>node.value[col]:
    55         traveltree(node.rb,aim)
    56     if aim[col]<node.value[col]:
    57         traveltree(node.lb,aim)
    58     dis=dist(node.value,aim)
    59     if len(knears)<k:
    60         knears.setdefault(tuple(node.value.tolist()),dis)#列表不能作为字典的键
    61         pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
    62     elif dis<=pointlist[0][1]:
    63         knears.setdefault(tuple(node.value.tolist()),dis)
    64         pointlist=sorted(knears.items(),key=lambda item: item[1],reverse=True)
    65     if node.rb!=None or node.lb!=None:
    66         if abs(aim[node.col] - node.value[node.col]) < pointlist[0][1]:
    67             if aim[node.col]<node.value[node.col]:
    68                 traveltree(node.rb,aim)
    69             if aim[node.col]>node.value[node.col]:
    70                 traveltree(node.lb,aim)
    71     return pointlist
    72          
    73 def dist(x1, x2): #欧式距离的计算  
    74     return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5  
    75 
    76 knears={}
    77 k=int(input('请输入k的值'))
    78 if k<2: print('k不能是1')
    79 global pointlist
    80 pointlist=[]
    81 file=input('请输入数据文件地址')
    82 data=readdata(file)
    83 tree=buildtree(data)
    84 tmp=input('请输入目标点')
    85 tmp=tmp.split(',')
    86 aim=[]
    87 for num in tmp:
    88     num=float(num)
    89     aim.append(num)
    90 aim=tuple(aim)
    91 pointlist=traveltree(tree,aim)
    92 for point in pointlist[-k:]:
    93     print(point)
    kdtree

     

  • 相关阅读:
    springboot拦截器的拦截配置和添加多个拦截器
    ASCII对照
    爬虫出现403错误解决办法
    PhantomJS在Selenium中被标记为过时的应对措施
    Selenium 之订制启动Chrome的选项(Options)
    Selenium+PhantomJS使用时报错原因及解决方案
    python爬虫之xpath的基本使用
    JSONObject类的引用必须jar包
    selenium之使用chrome浏览器测试(附chromedriver与chrome的对应关系表)
    PhantomJS 与python的结合
  • 原文地址:https://www.cnblogs.com/bambipai/p/8443182.html
Copyright © 2011-2022 走看看