zoukankan      html  css  js  c++  java
  • kd树的创建和求最近邻

     1 import numpy as np
     2 arr = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
     3 arr.shape
     4 
     5 class KDTree():
     6     def __init__(self):
     7         self.value = None
     8         self.left = None
     9         self.right = None
    10         self.axis = None
    11 
    12 def create(arr, k, h=0):
    13     if arr.shape[0] == 0:
    14         return None
    15     tree = KDTree()
    16     axis = h % k
    17     
    18     if arr.shape[0] == 1:
    19         tree.value = arr[0]
    20         tree.left = None
    21         tree.right = None
    22         tree.axis = axis
    23     else:
    24         arr = sorted(arr, key = lambda x:x[axis])
    25         arr = np.array(arr)
    26         i = arr.shape[0]//2
    27         
    28         tree.value =  arr[i]
    29         tree.left = create(arr[0:i], k, h+1)
    30         tree.right = create(arr[i+1:], k, h+1)
    31         tree.axis = axis
    32     return tree
    33 
    34 k = KDTree()
    35 
    36 k = create(arr, arr.shape[1])
    37 
    38 def preOrder(k):
    39     print('当前节点:' + str(k.value))
    40     
    41     if k.left:
    42         preOrder(k.left)
    43     if k.right:
    44         preOrder(k.right)
    45 
    46 preOrder(k)
    47 
    48 def dis(a, b):
    49     return np.linalg.norm(a-b)
    50 def search(kd, goal, k, h=0):
    51     '''输入:kd树,目标点、特征维度k以及当前深度h'''
    52     '''输出:在kd树上的与目标点距离(欧氏距离)最近的距离'''
    53     if kd.left == None and kd.right == None:
    54         return dis(goal, kd.value)
    55     if kd.left == None:
    56         return min(search(kd.right, goal, k, h+1), dis(kd.value, goal))
    57     if kd.right == None:
    58         return min(search(kd.left, goal, k, h+1), dis(kd.value, goal))
    59     axis = h%k
    60     
    61     if goal[axis] < kd.value[axis]:
    62         cur_dis = search(kd.left, goal, k, h+1)
    63     else:
    64         cur_dis = search(kd.right, goal, k, h+1)
    65     
    66     
    67     if cur_dis < kd.value[axis]-goal[axis]:////cut  取绝对值
    68         return cur_dis;
    69     else:
    70         if goal[axis] < kd.value[axis]:
    71             cur_dis = min(search(kd.right, goal, k, h+1), cur_dis, dis(kd.value, goal))
    72         else:
    73             cur_dis = min(search(kd.left, goal, k, h+1), cur_dis, dis(kd.value, goal))
    74     return cur_dis
    75 
    76 search(k, np.array([9, 6]), 2)
  • 相关阅读:
    使用Astah画UML类图经验总结
    Qt的四个常见的图像叠加模式
    获取Linux时间函数
    DBus学习网站
    线程属性pthread_attr_t简介
    Secure CRT 自动记录日志log配置
    MySQL的group_concat()函数合并多行数据
    MySQL的Limit详解
    异步查询json传日期格式到前台,变成了时间戳的格式
    启动studio报错Gradle error
  • 原文地址:https://www.cnblogs.com/liuwenhan/p/11723354.html
Copyright © 2011-2022 走看看