KNN算法的介绍请参考:
http://blog.csdn.net/zouxy09/article/details/16955347
《统计学习方法》一书中给出了 kd 树(KD Tree)的算法介绍,下面按照书上的描述进行了实现:
# -*- coding: utf-8 -*- from operator import itemgetter from copy import deepcopy import heapq class Node(object): def __init__(self, dim, label=None, parent = None, split = 0): """ kd树的节点 :param dim: 节点的向量 :param label: 节点的label :param parent: 父节点 :param split: 第 split 进行切分 :return: """ self.dim = deepcopy(dim) self.label = deepcopy(label) self.left_node = None self.right_node = None self.parent = parent self.split = split class KdTree(object): def __init__(self): """ j 主要是一个递增值,用来计算当前使用哪个维度进行切面 :return: """ self.j = 0 def __get_kd_tree(self, samples, k, parent_node): """ 生成kd树,主要采用递归 :param samples: 样本[[(1,2,3),'A'],[(2,3,4),'b']] :param k: 样本的维度 :param parent_node: 父节点 :return: """ if samples is None or len(samples) == 0: return None #计算切面 l = self.j % k self.j = self.j + 1 #对样本进行排序,并取中位数 samples.sort(key=lambda s:s[0][l]) len_sam = len(samples) mid_index = len_sam / 2 mid_value = samples[mid_index][0][l] i = 0 while i < len_sam and samples[i][0][l] < mid_value: i += 1 #将中位数对应的样本设置为当前节点 root_node = Node(samples[i][0], samples[i][1]) root_node.parent = parent_node root_node.split = l if 0 == i: left_samples = [] else: left_samples = samples[0:i] if i >= len_sam - 1: right_samples = [] else: right_samples = samples[i+1:] #del(samples[mid_index]) root_node.left_node = self.__get_kd_tree(left_samples, k, root_node) root_node.right_node = self.__get_kd_tree(right_samples, k, root_node) return root_node def get_kd_tree(self, samples): """ :param samples: [[(1,2,3),'A'],[(2,3,4),'b']] :return: """ return self.__get_kd_tree(samples, len(samples[0][0]), None) def cal_dist(self, target, sample): """ 欧拉距离 :param target: 目标样本 :param sample: 需要计算距离的样本 :return: """ dis = 0. 
for i in range(0, len(target.dim)): dis += (target.dim[i] - sample.dim[i]) ** 2 dis = dis ** 0.5 return dis def __insert_heap(self, k, dis, node, heap): """ python 的 heap 是小顶堆, 将数值设置为负数,就变成了大顶堆 [(-dis,node)] :param k: :param dis: :param node: :param heap: :return: """ if len(heap) < k: heapq.heappush(heap, (-dis, node)) else: d = - heap[0][0] if dis < d: heapq.heapreplace(heap, (-dis, node)) def get_k_neighbors(self, target, root, k, heap): """ :param root: kd树的根节点 :param k: 邻居个数 :return: """ if root is None: return s = root.split if target.dim[s] < root.dim[s]: self.get_k_neighbors(target, root.left_node, k, heap) else: self.get_k_neighbors(target, root.right_node, k, heap) dis = self.cal_dist(target, root) self.__insert_heap(k, dis, root, heap) if root.parent is not None: father_s = root.parent.split check_node = None if root.dim[father_s] < root.parent.dim[father_s]: check_node = root.parent.right_node else: check_node = root.parent.left_node smallest = heapq.nlargest(1, heap) if check_node is not None and self.cal_dist(target, check_node) < -smallest[0][0]: self.get_k_neighbors(target, check_node, k, heap) else: return def get_label_of_sample(self, heap): lable_dict = {} for i in range(0, len(heap)): node_label = heap[i][1].label if lable_dict.has_key(node_label): lable_dict[node_label] = lable_dict[node_label] + 1 else: lable_dict[node_label] = 1 max = 0 max_label = '' for key in lable_dict.keys(): if lable_dict[key] > max: max = lable_dict[key] max_label = key return max_label if __name__ == '__main__': samples = [[(2,3),"A"],[(5,4),"B"],[(9,6),"C"],[(4,7),"D"],[(8,1),"E"],[(7,2),"F"]] kd = KdTree() kd_root = kd.get_kd_tree(samples) print kd_root.dim heap = [] target_node = Node((2.1, 3.1), "P") kd.get_k_neighbors(target_node, kd_root, 2, heap) print heap[0][1].dim print heap[1][1].dim print kd.get_label_of_sample(heap) print samples
实现之后,利用上面博客提供的手写数字数据集进行了一下测试:
# -*- coding: utf-8 -*- import os import numpy as np import kd_tree class KnnDigits(object): def __init__(self): pass def img2array(self, filename): """ :return: """ rows = 32 cols = 32 img_array = np.zeros(rows * cols) with open(filename) as read_fp: for row in xrange(0, rows): line_str = read_fp.readline() for col in xrange(0, cols): img_array[row * rows + col] = int(line_str[col]) #img_array[row] += int(line_str[col]) return img_array def load_data(self, data_dir): """ :param data_dir: :return: """ samples = [] files_list = os.listdir(data_dir) num_samples = len(files_list) for i in xrange(0, num_samples): file_name = os.path.join(data_dir, files_list[i]) img_array = self.img2array(file_name) img_label = int(files_list[i].split('_')[0]) samples.append([img_array, img_label]) return samples def run(self, train_dir, test_dir): """ :param train_dir: :param test_dir: :return: """ train_samples = self.load_data(train_dir) test_samples = self.load_data(test_dir) kd = kd_tree.KdTree() kd_root = kd.get_kd_tree(train_samples) nums_test_samples = len(test_samples) match_count = 0 for i in range(0, nums_test_samples): heap = [] target_node = kd_tree.Node(test_samples[i][0], test_samples[i][1]) kd.get_k_neighbors(target_node, kd_root, 3, heap) pridict_label = kd.get_label_of_sample(heap) if pridict_label == test_samples[i][1]: match_count += 1 #print "pridict label is %s and test lable is %s" %(pridict_label, test_samples[i][1]) accur = float(match_count) / nums_test_samples return accur if __name__ == '__main__': train_dir = "/Users/baidu/PycharmProjects/statistics_learning_method/digits/trainingDigits" test_dir = "/Users/baidu/PycharmProjects/statistics_learning_method/digits/testDigits" knn = KnnDigits() print knn.run(train_dir, test_dir)
我的娘亲哟,测试集上的准确率只有 0.879492600423
这说明 kd 树实现得不好:生成的树不平衡,并且很可能有 BUG。
算法改进:
http://my.oschina.net/keyven/blog/221792