zoukankan      html  css  js  c++  java
  • KNN算法

    KNN算法的介绍请参考:

    http://blog.csdn.net/zouxy09/article/details/16955347

    统计学习方法里面给出了KD Tree的算法介绍,按照书上的进行了实现:

    # -*- coding: utf-8 -*-
    
    from operator import itemgetter
    from copy import deepcopy
    import heapq
    
    
    class Node(object):
        def __init__(self, dim, label=None, parent = None,
                     split = 0):
            """
            kd树的节点
            :param dim: 节点的向量
            :param label: 节点的label
            :param parent: 父节点
            :param split: 第 split 进行切分
            :return:
            """
            self.dim = deepcopy(dim)
            self.label = deepcopy(label)
            self.left_node = None
            self.right_node = None
            self.parent = parent
            self.split = split
    
    
    class KdTree(object):
    
        def __init__(self):
            """
            j 主要是一个递增值,用来计算当前使用哪个维度进行切面
            :return:
            """
            self.j = 0
    
        def __get_kd_tree(self, samples, k, parent_node):
            """
            生成kd树,主要采用递归
            :param samples: 样本[[(1,2,3),'A'],[(2,3,4),'b']]
            :param k: 样本的维度
            :param parent_node: 父节点
            :return:
            """
    
            if samples is None or len(samples) == 0:
                return None
    
            #计算切面
            l = self.j % k
            self.j = self.j + 1
    
            #对样本进行排序,并取中位数
            samples.sort(key=lambda s:s[0][l])
            len_sam = len(samples)
            mid_index = len_sam / 2
            mid_value = samples[mid_index][0][l]
            i = 0
            while i < len_sam and samples[i][0][l] < mid_value:
                i += 1
    
            #将中位数对应的样本设置为当前节点
            root_node = Node(samples[i][0], samples[i][1])
            root_node.parent = parent_node
            root_node.split = l
            if 0 == i:
                left_samples = []
            else:
                left_samples = samples[0:i]
            if i >= len_sam - 1:
                right_samples = []
            else:
                right_samples = samples[i+1:]
    
            #del(samples[mid_index])
            root_node.left_node = self.__get_kd_tree(left_samples, k, root_node)
            root_node.right_node = self.__get_kd_tree(right_samples, k, root_node)
    
            return root_node
    
        def get_kd_tree(self, samples):
            """
            :param samples: [[(1,2,3),'A'],[(2,3,4),'b']]
            :return:
            """
            return  self.__get_kd_tree(samples, len(samples[0][0]), None)
    
        def cal_dist(self, target, sample):
            """
            欧拉距离
            :param target: 目标样本
            :param sample: 需要计算距离的样本
            :return:
            """
    
            dis = 0.
    
            for i in range(0, len(target.dim)):
               dis += (target.dim[i] - sample.dim[i]) ** 2
            dis = dis ** 0.5
    
            return dis
    
    
        def __insert_heap(self, k, dis, node, heap):
            """
            python 的 heap 是小顶堆, 将数值设置为负数,就变成了大顶堆
            [(-dis,node)]
            :param k:
            :param dis:
            :param node:
            :param heap:
            :return:
            """
            if len(heap) < k:
                heapq.heappush(heap, (-dis, node))
            else:
                d  = - heap[0][0]
                if dis < d:
                    heapq.heapreplace(heap, (-dis, node))
    
        def get_k_neighbors(self, target, root, k, heap):
            """
    
            :param root: kd树的根节点
            :param k: 邻居个数
            :return:
            """
            if root is None:
                return
    
            s = root.split
            if target.dim[s] < root.dim[s]:
                self.get_k_neighbors(target, root.left_node, k, heap)
            else:
                 self.get_k_neighbors(target, root.right_node, k, heap)
    
            dis = self.cal_dist(target, root)
            self.__insert_heap(k, dis, root, heap)
    
            if root.parent is not None:
                father_s = root.parent.split
                check_node = None
                if root.dim[father_s] < root.parent.dim[father_s]:
                    check_node = root.parent.right_node
                else:
                    check_node = root.parent.left_node
    
                smallest = heapq.nlargest(1, heap)
                if check_node is not None and self.cal_dist(target, check_node) < -smallest[0][0]:
                    self.get_k_neighbors(target, check_node, k, heap)
                else:
                    return
    
        def get_label_of_sample(self, heap):
            lable_dict = {}
            for i in range(0, len(heap)):
                node_label = heap[i][1].label
                if lable_dict.has_key(node_label):
                    lable_dict[node_label] = lable_dict[node_label] + 1
                else:
                    lable_dict[node_label] = 1
    
            max = 0
            max_label = ''
            for key in lable_dict.keys():
                if lable_dict[key] > max:
                    max = lable_dict[key]
                    max_label = key
            return max_label
    
    
    if __name__ == '__main__':
        samples = [[(2,3),"A"],[(5,4),"B"],[(9,6),"C"],[(4,7),"D"],[(8,1),"E"],[(7,2),"F"]]
        kd = KdTree()
        kd_root = kd.get_kd_tree(samples)
        print kd_root.dim
        heap = []
        target_node = Node((2.1, 3.1), "P")
        kd.get_k_neighbors(target_node, kd_root, 2, heap)
        print heap[0][1].dim
        print heap[1][1].dim
        print kd.get_label_of_sample(heap)
        print samples
    kd tree

    实现了后,利用上面博客给的手写数据集进行了下测试

    # -*- coding: utf-8 -*-
    
    import os
    
    import numpy as np
    
    import kd_tree
    
    class KnnDigits(object):
    
        def __init__(self):
            pass
    
        def img2array(self, filename):
            """
    
            :return:
            """
            rows = 32
            cols = 32
    
            img_array = np.zeros(rows * cols)
    
            with open(filename) as read_fp:
                for row in xrange(0, rows):
                    line_str = read_fp.readline()
                    for col in xrange(0, cols):
                        img_array[row * rows + col] = int(line_str[col])
                        #img_array[row] += int(line_str[col])
            return img_array
    
        def load_data(self, data_dir):
            """
    
            :param data_dir:
            :return:
            """
            samples = []
            files_list = os.listdir(data_dir)
            num_samples = len(files_list)
            for i in xrange(0, num_samples):
                file_name = os.path.join(data_dir, files_list[i])
                img_array = self.img2array(file_name)
                img_label = int(files_list[i].split('_')[0])
                samples.append([img_array, img_label])
    
            return samples
    
        def run(self, train_dir, test_dir):
            """
    
            :param train_dir:
            :param test_dir:
            :return:
            """
            train_samples = self.load_data(train_dir)
            test_samples = self.load_data(test_dir)
            kd = kd_tree.KdTree()
            kd_root = kd.get_kd_tree(train_samples)
            nums_test_samples = len(test_samples)
            match_count = 0
            for i in range(0, nums_test_samples):
                heap = []
                target_node = kd_tree.Node(test_samples[i][0], test_samples[i][1])
                kd.get_k_neighbors(target_node, kd_root, 3, heap)
                pridict_label = kd.get_label_of_sample(heap)
                if pridict_label == test_samples[i][1]:
                    match_count += 1
                #print "pridict label is %s and test lable is %s" %(pridict_label, test_samples[i][1])
            accur = float(match_count) / nums_test_samples
    
            return accur
    
    if __name__ == '__main__':
        train_dir = "/Users/baidu/PycharmProjects/statistics_learning_method/digits/trainingDigits"
        test_dir = "/Users/baidu/PycharmProjects/statistics_learning_method/digits/testDigits"
    
        knn = KnnDigits()
        print knn.run(train_dir, test_dir)
    View Code

    我的娘亲哟,结果只有0.879492600423

    这说明kd 树实现的不好,并且生成的树不平衡,并且很可能有BUG。

    算法改进:

    http://my.oschina.net/keyven/blog/221792

  • 相关阅读:
    EasyDSS功能简介视频直播、直播鉴权(如何完美将EasyDSS过渡到新版)
    EasyNVR前端构建之输入框样式的调整
    NVR硬件录像机web无插件播放方案(支持取特定时间段视频流)
    Windows操作系统远程Linux服务器传输文件方法(以EasyDSS云平台、EasyNVR上传部署为例)
    零基础实现摄像头的全平台直播 (二)公网直播的实现
    海康、大华NVR硬件录像机录像无插件全平台访问实现播放时间轴实现
    直播与虚拟直播
    CF585EPresent for Vitalik the Philatelist【莫比乌斯反演,狄利克雷前缀和】
    AT4519[AGC032D]Rotation Sort【dp】
    P5110块速递推【特征方程,分块】
  • 原文地址:https://www.cnblogs.com/SpeakSoftlyLove/p/5229346.html
Copyright © 2011-2022 走看看