import numpy as np

arr = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])


class KDTree:
    """A node of a k-d tree.

    Attributes:
        value: the point stored at this node (one row of the input array).
        left: subtree holding the rows that sorted before ``value`` on ``axis``.
        right: subtree holding the rows that sorted after ``value`` on ``axis``.
        axis: index of the coordinate this node splits on.
    """

    def __init__(self, value=None, left=None, right=None, axis=None):
        self.value = value
        self.left = left
        self.right = right
        self.axis = axis


def create(arr, k, h=0):
    """Build a k-d tree from the rows of ``arr`` and return its root.

    Args:
        arr: 2-D array-like of shape (n, k); each row is a point.
        k: number of feature dimensions; the split axis cycles as ``h % k``.
        h: depth of the node being built (0 for the root).

    Returns:
        The root ``KDTree`` node, or ``None`` when ``arr`` has no rows.
    """
    arr = np.asarray(arr)
    if arr.shape[0] == 0:
        return None
    axis = h % k
    # Stable sort on the split axis, then use the median row as this node;
    # rows before it form the left subtree, rows after it the right one.
    arr = arr[np.argsort(arr[:, axis], kind="stable")]
    i = arr.shape[0] // 2
    return KDTree(
        value=arr[i],
        left=create(arr[:i], k, h + 1),
        right=create(arr[i + 1:], k, h + 1),
        axis=axis,
    )


def preOrder(k):
    """Print every node's point in pre-order (node, left subtree, right subtree).

    ``k`` is a tree root as returned by ``create``; ``None`` prints nothing.
    """
    if k is None:
        return
    print('当前节点:' + str(k.value))  # prints "current node: <point>"
    preOrder(k.left)
    preOrder(k.right)


def dis(a, b):
    """Return the Euclidean distance between points ``a`` and ``b``."""
    # Cast to float so unsigned-integer inputs cannot wrap around on subtraction.
    return np.linalg.norm(np.asarray(a, dtype=float) - np.asarray(b, dtype=float))


def search(kd, goal, k, h=0):
    """Return the Euclidean distance from ``goal`` to its nearest point in the tree.

    Args:
        kd: root of the (sub)tree to search, as built by ``create``; may be None.
        goal: target point with ``k`` coordinates.
        k: number of feature dimensions (the same value passed to ``create``).
        h: depth of ``kd``; the split axis is ``h % k``, mirroring ``create``.

    Returns:
        The smallest distance found, or ``float("inf")`` for an empty tree.
    """
    if kd is None:
        return float("inf")
    axis = h % k
    diff = goal[axis] - kd.value[axis]
    # Descend first into the side of the splitting plane that contains goal.
    if diff < 0:
        near, far = kd.left, kd.right
    else:
        near, far = kd.right, kd.left
    best = min(dis(kd.value, goal), search(near, goal, k, h + 1))
    # Every point on the far side is at least |diff| away from goal along this
    # axis, so that subtree can only improve the answer when |diff| < best.
    if abs(diff) < best:
        best = min(best, search(far, goal, k, h + 1))
    return best


k = create(arr, arr.shape[1])
preOrder(k)
search(k, np.array([9, 6]), 2)