zoukankan      html  css  js  c++  java
  • KNN算法之KD树(K-dimension Tree)实现 K近邻查询

    KD树是一种分割k维数据空间的数据结构,主要应用于多维空间关键数据的搜索,如范围搜索和最近邻搜索。

    KD树使用了分治的思想,对比二叉搜索树(BST),KD树解决的是多维空间内的最近点(K近点)问题。(思想与之前见过的最近点对问题很相似,将所有点分为两边,对于可能横跨划分线的点对再进一步讨论)

    KD树用来优化KNN算法中的查询复杂度。

    一、建树

    建立KDtree,主要有两步操作:选择合适的分割维度选择中值节点作为分割节点

    分割维度的选择遵循的原则是,选择范围最大的纬度,也即是方差最大的纬度作为分割维度,为什么方差最大的适合作为特征呢?

    因为方差大,数据相对“分散”,选取该特征来对数据集进行分割,数据散得更“开”一些。

    分割节点的选择原则是,将这一维度的数据进行排序,选择正中间的节点作为分割节点,确保节点左边的点的维度值小于节点的维度值,节点右边的点的维度值大于节点的维度值。

    这两步步骤影响搜索效率,非常关键。

    二、搜索K近点

    需要的数据结构:最大堆(此处我对距离取负从而用最小堆实现的最大堆,因为python的heapq模块只有最小堆)、堆栈(用列表实现)

    a.利用二叉搜索找到叶子节点并将搜索的结点路径压入堆栈stack中

    b.通过保存在堆栈中的搜索路径回溯,直至堆栈中没有结点了

    对于b步骤,需要区分叶子结点和非叶结点:

    1、叶子结点

    叶子结点:计算与目标点的距离。若候选堆中不足K个点,将叶子结点加入候选堆中;如果K个点够了,判断是否比候选堆中距离最小的结点(因为距离取了相反数)还要大, 说明应当加入候选堆;

    2、非叶结点

    对于非叶结点,处理步骤和叶子结点差不多,只是需要额外考虑以目标点为圆心,最大堆的堆顶元素为半径的超球体是否和划分当前空间的超平面相交,如果相交说明未访问的另一边的空间有可能包含比当前已有的K个近点更近的点,需要搜索另一边的空间;此外,当候选堆中没有K个点,那么不管有没有相交,都应当搜索未访问的另一边空间,因为候选堆的点不够K个。

    步骤:计算与目标点的距离
    1、若不足K个点,将结点加入候选堆中;
    如果K个点够了,判断是否比候选堆中距离最小的结点(因为距离取了相反数)还要大。

    2、判断候选堆中的最小距离是否小于Xi离当前超平面的距离(即是否需要判断未访问的另一边要不要搜索)当然如果不足K个点,虽然超平面不相交,依旧要搜索另一边,直到找到叶子结点,并且把路径加入回溯栈中。

    三、预测

    KNN通常用来分类或者回归问题,笔者已经封装好了两种预测的方法。

    python代码实现:

      1 import heapq
      2 class KDNode(object):
      3     def __init__(self,feature=None,father=None,left=None,right=None,split=None):
      4         self.feature=feature # dimension index (按第几个维度的特征进行划分的)
      5         self.father=father  #并没有用到
      6         self.left=left
      7         self.right=right
      8         self.split=split # X value and Y value (元组,包含特征X和真实值Y)
      9 
     10 class KDTree(object):
     11     def __init__(self):
     12         self.root=KDNode()
     13         pass
     14 
     15     def _get_variance(self,X,row_indexes,feature_index):
     16         # X (2D list): samples * dimension
     17         # row_indexes (1D list): choose which row can be calculated
     18         # feature_index (int): calculate which dimension
     19         n = len(row_indexes)
     20         sum1 = 0
     21         sum2 = 0
     22         for id in row_indexes:
     23             sum1 = sum1 + X[id][feature_index]
     24             sum2 = sum2 + X[id][feature_index]**2
     25 
     26         return sum2/n - (sum1/n)**2
     27 
     28     def _get_max_variance_feature(self,X,row_indexes):
     29         mx_var = -1
     30         dim_index = -1
     31         for dim in range(len(X[0])):
     32             dim_var = self._get_variance(X,row_indexes,dim)
     33             if dim_var>mx_var:
     34                 mx_var=dim_var
     35                 dim_index=dim
     36         # return max variance feature index (int)
     37         return dim_index
     38 
     39     def _get_median_index(self,X,row_indexes,feature_index):
     40         median_index =  len(row_indexes)//2
     41         select_X = [(idx,X[idx][feature_index]) for idx in row_indexes]
     42         sorted_X = select_X
     43         sorted(sorted_X,key= lambda x:x[1])
     44         #return median index in feature_index dimension (int)
     45         return sorted_X[median_index][0]
     46 
     47     def _split_feature(self,X,row_indexes,feature_index,median_index):
     48         left_ids = []
     49         right_ids = []
     50         median_val = X[median_index][feature_index]
     51         for id in row_indexes:
     52             if id==median_index:
     53                 continue
     54             val = X[id][feature_index]
     55             if val < median_val:
     56                 left_ids.append(id)
     57             else:
     58                 right_ids.append(id)
     59         # return (left points index and right points index)(list,list)
     60         # 把当前的样本按feature维度进行划分为两份
     61         return left_ids, right_ids
     62 
     63     def build_tree(self,X,Y):
     64         row_indexes =[i for i in range(len(X))]
     65         node =self.root
     66         queue = [(node,row_indexes)]
     67         # BFS创建KD树
     68         while queue:
     69             root,ids = queue.pop(0)
     70             if len(ids)==1:
     71                 root.feature = 0 #如果是叶子结点,维度赋0
     72                 root.split = (X[ids[0]],Y[ids[0]])
     73                 continue
     74             # 选取方差最大的特征维度划分,取样本的中位数作为median
     75             feature_index = self._get_max_variance_feature(X,ids)
     76             median_index = self._get_median_index(X,ids,feature_index)
     77             left_ids,right_ids = self._split_feature(X,ids,feature_index,median_index)
     78             root.feature = feature_index
     79             root.split = (X[median_index],Y[median_index])
     80             if left_ids:
     81                 root.left = KDNode()
     82                 root.left.father = root
     83                 queue.append((root.left,left_ids))
     84             if right_ids:
     85                 root.right = KDNode()
     86                 root.right.father = root
     87                 queue.append((root.right,right_ids))
     88 
     89     def _get_distance(self,Xi,node,p=2):
     90         # p=2  default Euclidean distance
     91         nx = node.split[0]
     92         dist = 0
     93         for i in range(len(Xi)):
     94             dist=dist + (abs(Xi[i]-nx[i])**p)
     95         dist = dist**(1/p)
     96         return dist
     97 
     98     def _get_hyperplane_distance(self,Xi,node):
     99         xx = node.split[0]
    100         dist = abs(Xi[node.feature] - xx[node.feature])
    101         return dist
    102 
    103     def _is_leaf(self,node):
    104         if node.left is None and node.right is None:
    105             return True
    106         else:
    107             return False
    108 
    109     def get_nearest_neighbour(self,Xi,K=1):
    110         search_paths = []
    111         max_heap = [] #use min heap achieve max heap (因为python只有最小堆)
    112         priority_num = 0  # remove same distance
    113         heapq.heappush(max_heap,(float('-inf'),priority_num,None))
    114         priority_num +=1
    115         node = self.root
    116         # 找到离Xi最近的叶子结点
    117         while node is not None:
    118             search_paths.append(node)
    119             if Xi[node.feature] < node.split[0][node.feature]:
    120                 node = node.left
    121             else:
    122                 node = node.right
    123 
    124         while search_paths:
    125             now = search_paths.pop()
    126             #  叶子结点:计算与Xi的距离,若不足K个点,将叶子结点加入候选堆中;
    127             #  如果K个点够了,判断是否比候选堆中距离最小的结点(因为距离取了相反数)还要大,
    128             #  说明应当加入候选堆;
    129             if self._is_leaf(now):
    130                 dist = self._get_distance(Xi,now)
    131                 dist = -dist
    132                 mini_dist = max_heap[0][0]
    133                 if len(max_heap) < K :
    134                     heapq.heappush(max_heap,(dist,priority_num,now))
    135                     priority_num+=1
    136                 elif dist > mini_dist:
    137                     _ = heapq.heappop(max_heap)
    138                     heapq.heappush(max_heap,(dist,priority_num,now))
    139                     priority_num+=1
    140             #  非叶结点:计算与Xi的距离
    141             #  1、若不足K个点,将结点加入候选堆中;
    142             #   如果K个点够了,判断是否比候选堆中距离最小的结点(因为距离取了相反数)还要大,
    143             #  2、判断候选堆中的最小距离是否小于Xi离当前超平面的距离(即是否需要判断另一边要不要搜索)
    144             #    当然如果不足K个点,虽然超平面不相交,依旧要搜索另一边,
    145             #    直到找到叶子结点,并且把路径加入回溯栈中
    146             else :
    147                 dist = self._get_distance(Xi, now)
    148                 dist = -dist
    149                 mini_dist = max_heap[0][0]
    150                 if len(max_heap) < K:
    151                     heapq.heappush(max_heap, (dist, priority_num, now))
    152                     priority_num += 1
    153                 elif dist > mini_dist:
    154                     _ = heapq.heappop(max_heap)
    155                     heapq.heappush(max_heap, (dist, priority_num, now))
    156                     priority_num += 1
    157 
    158                 mini_dist = max_heap[0][0]
    159                 if len(max_heap)<K or -(self._get_hyperplane_distance(Xi,now)) > mini_dist:
    160                     # search another child tree
    161                     if Xi[now.feature] >= now.split[0][now.feature]:
    162                         child_node = now.left
    163                     else:
    164                         child_node = now.right
    165                     # record path until find child leaf node
    166                     while child_node is not None:
    167                         search_paths.append(child_node)
    168                         if Xi[child_node.feature] < child_node.split[0][child_node.feature]:
    169                             child_node = child_node.left
    170                         else:
    171                             child_node = child_node.right
    172         return max_heap
    173 
    174     def predict_classification(self,Xi,K=1):
    175         # 多分类问题预测
    176         y =self.get_nearest_neighbour(Xi,K)
    177         mp = {}
    178         for i in y:
    179             if i[2].split[1] in mp:
    180                 mp[i[2].split[1]]+=1
    181             else:
    182                 mp[i[2].split[1]]=1
    183         pre_y = -1
    184         max_cnt =-1
    185         for k,v in mp.items():
    186             if v>max_cnt:
    187                 max_cnt=v
    188                 pre_y=k
    189         return pre_y
    190 
    191     def predict_regression(self,Xi,K=1):
    192         #回归问题预测
    193         pre_y = self.get_nearest_neighbour(Xi,K)
    194         return sum([i[2].split[1] for i in pre_y])/K
    195 
    196 
    197 # t =KDTree()
    198 # xx = [[3,3],[1,2],[5,6],[999,999],[5,5]]
    199 # z = [1,0,1,1,1]
    200 # t.build_tree(xx,z)
    201 # y=t.predict_regression([4,5.5],K=5)
    202 # print(y)

    参考资料:

    《统计学习方法》——李航著

    https://blog.csdn.net/qq_32478489/article/details/82972391?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param

    https://zhuanlan.zhihu.com/p/45346117

    https://www.cnblogs.com/xingzhensun/p/9693362.html

  • 相关阅读:
    创建或者连接管道+++检查管道空间是否够写入本消息++++删除管道
    从instr中截取第一个delimiter之前的内容放到outstr中,返回第一个delimiter之后的位置
    把数字按网络顺序或主机顺序存放到字符串中++++把字符串按网络顺序转换成数字++++把字符串按主机顺序转换成数字
    压缩空格的函数以及BCD码与ASCII相互转换函数
    判断文件是否存在
    把指定长度字符串转换成数字
    找到特定串在源字符串中的位置
    FTP命令详解
    docker 学习路线
    云原生技术的了解
  • 原文地址:https://www.cnblogs.com/ISGuXing/p/13762636.html
Copyright © 2011-2022 走看看