zoukankan      html  css  js  c++  java
  • 2. KNN和KdTree算法实现

    1. K近邻算法(KNN)

    2. KNN和KdTree算法实现

    1. 前言

    KNN一直是一个机器学习入门需要接触的第一个算法,它有着简单,易懂,可操作性强的一些特点。今天我久带领大家先看看sklearn中KNN的使用,在带领大家实现出自己的KNN算法。

    2. KNN在sklearn中的使用

    knn在sklearn中是放在sklearn.neighbors的包中的,我们今天主要介绍KNeighborsClassifier的分类器。

    KNeighborsClassifier的主要参数是:

    参数 意义
    n_neighbors K值的选择与样本分布有关,一般选择一个较小的K值,可以通过交叉验证来选择一个比较优的K值,默认值是5
    weights ‘uniform’是每个点权重一样,‘distance’则权重和距离成反比例,即距离预测目标更近的近邻具有更高的权重
    algorithm ‘brute’对应第一种蛮力实现,‘kd_tree’对应第二种KD树实现,‘ball_tree’对应第三种的球树实现, ‘auto’则会在上面三种算法中做权衡,选择一个拟合最好的最优算法。
    leaf_size 这个值控制了使用KD树或者球树时, 停止建子树的叶子节点数量的阈值。
    metric K近邻法和限定半径最近邻法类可以使用的距离度量较多,一般来说默认的欧式距离(即p=2的闵可夫斯基距离)就可以满足我们的需求。
    p p是使用距离度量参数 metric 附属参数,只用于闵可夫斯基距离和带权重闵可夫斯基距离中p值的选择,p=1为曼哈顿距离, p=2为欧式距离。默认为2

    我个人认为这些个参数,比较重要的应该属n_neighbors、weights了,其他默认的也都没太大问题。

    3. KNN基础版实现

    直接看代码如下,完整代码GitHub

    def fit(self, X_train, y_train):
        self.X_train = X_train
        self.y_train = y_train
    
    def predict(self, X):
        # 取出n个点
        knn_list = []
        for i in range(self.n):
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            knn_list.append((dist, self.y_train[i]))
    
        for i in range(self.n, len(self.X_train)):
            max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            if knn_list[max_index][0] > dist:
                knn_list[max_index] = (dist, self.y_train[i])
    
        # 统计
        knn = [k[-1] for k in knn_list]
        return Counter(knn).most_common()[0][0]
    

    我的接口设计都是按照sklearn的样子设计的,fit方法其实主要用来接收参数了,没有进行任何的处理。所有的操作都在predict中,着就会导致,我们对每个点预测的时候,时间消耗比较大。这个基础版本大家看看就好,我想大家自己去写,肯定也没问题的。

    4. KdTree版本实现

    kd树算法包括三步,第一步是建树,第二部是搜索最近邻,最后一步是预测。

    4.1 构建kd树

    kd树是一种对n维空间的实例点进行存储,以便对其进行快速检索的树形结构。kd树是二叉树,构造kd树相当于不断的用垂直于坐标轴的超平面将n维空间进行划分,构成一系列的n维超矩阵区域。

    下面的流程图更加清晰的描述了kd树的构建过程:

    image

    kdtree树的生成代码:

    # 建立kdtree
    def create(self, dataSet, label, depth=0):
        if len(dataSet) > 0:
            m, n = np.shape(dataSet)
            self.n = n
            axis = depth % self.n
            mid = int(m / 2)
            dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
            node = Node(dataSetcopy[mid], label[mid], depth)
            if depth == 0:
                self.KdTree = node
            node.lchild = self.create(dataSetcopy[:mid], label, depth+1)
            node.rchild = self.create(dataSetcopy[mid+1:], label, depth+1)
            return node
        return None
    

    4.2 kd树搜索最近邻和预测

    当我们生成kd树以后,就可以去预测测试集里面的样本目标点了。预测的过程如下:

    1. 对于一个目标点,我们首先在kd树里面找到包含目标点的叶子节点。以目标点为圆心,以目标点到叶子节点样本实例的距离为半径,得到一个超球体,最近邻的点一定在这个超球体内部。
    2. 然后返回叶子节点的父节点,检查另一个子节点包含的超矩形体是否和超球体相交,如果相交就到这个子节点寻找是否有更加近的近邻,有的话就更新最近邻,并且更新超球体。如果不相交那就简单了,我们直接返回父节点的父节点,在另一个子树继续搜索最近邻。
    3. 当回溯到根节点时,算法结束,此时保存的最近邻节点就是最终的最近邻。
      kdtree树的搜索代码:
    # 搜索kdtree的前count个近的点
    def search(self, x, count = 1):
        nearest = []
        for i in range(count):
            nearest.append([-1, None])
        # 初始化n个点,nearest是按照距离递减的方式
        self.nearest = np.array(nearest)
    
        def recurve(node):
            if node is not None:
                # 计算当前点的维度axis
                axis = node.depth % self.n
                # 计算测试点和当前点在axis维度上的差
                daxis = x[axis] - node.data[axis]
                # 如果小于进左子树,大于进右子树
                if daxis < 0:
                    recurve(node.lchild)
                else:
                    recurve(node.rchild)
                # 计算预测点x到当前点的距离dist
                dist = np.sqrt(np.sum(np.square(x - node.data)))
                for i, d in enumerate(self.nearest):
                    # 如果有比现在最近的n个点更近的点,更新最近的点
                    if d[0] < 0 or dist < d[0]:
                        # 插入第i个位置的点
                        self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
                        # 删除最后一个多出来的点
                        self.nearest = self.nearest[:-1]
                        break
    
                # 统计距离为-1的个数n
                n = list(self.nearest[:, 0]).count(-1)
                '''
                self.nearest[-n-1, 0]是当前nearest中已经有的最近点中,距离最大的点。
                self.nearest[-n-1, 0] > abs(daxis)代表以x为圆心,self.nearest[-n-1, 0]为半径的圆与axis
                相交,说明在左右子树里面有比self.nearest[-n-1, 0]更近的点
                '''
                if self.nearest[-n-1, 0] > abs(daxis):
                    if daxis < 0:
                        recurve(node.rchild)
                    else:
                        recurve(node.lchild)
    
        recurve(self.KdTree)
    
        # nodeList是最近n个点的
        nodeList = self.nearest[:, 1]
    
        # knn是n个点的标签
        knn = [node.label for node in nodeList]
        return self.nearest[:, 1], Counter(knn).most_common()[0][0]
    

    这段代码其实比较好的实现了上面搜索的思想。如果读者对递归的过程想不太清楚,可以画下图,或者debug下我完整的代码GitHub

    5. 总结

    本文实现了KNN的基础版和KdTree版本,读者可以去尝试下ballTree的版本,理论上效率比KdTree还要好一些。

  • 相关阅读:
    Java实现 LeetCode 50 Pow(x,n)
    Java实现 LeetCode 50 Pow(x,n)
    Java实现 LeetCode 49 字母异位词分组
    Java实现 LeetCode 49 字母异位词分组
    Java实现 LeetCode 49 字母异位词分组
    Java实现 LeetCode 48 旋转图像
    Java实现 LeetCode 48 旋转图像
    Java实现 LeetCode 48 旋转图像
    Java实现 LeetCode 47 全排列 II(二)
    Java实现 LeetCode 47 全排列 II(二)
  • 原文地址:https://www.cnblogs.com/huangyc/p/10294307.html
Copyright © 2011-2022 走看看