zoukankan      html  css  js  c++  java
  • 2. KNN和KdTree算法实现

    1. K近邻算法(KNN)

    2. KNN和KdTree算法实现

    1. 前言

    KNN一直是一个机器学习入门需要接触的第一个算法,它有着简单,易懂,可操作性强的一些特点。今天我久带领大家先看看sklearn中KNN的使用,在带领大家实现出自己的KNN算法。

    2. KNN在sklearn中的使用

    knn在sklearn中是放在sklearn.neighbors的包中的,我们今天主要介绍KNeighborsClassifier的分类器。

    KNeighborsClassifier的主要参数是:

    参数 意义
    n_neighbors K值的选择与样本分布有关,一般选择一个较小的K值,可以通过交叉验证来选择一个比较优的K值,默认值是5
    weights ‘uniform’是每个点权重一样,‘distance’则权重和距离成反比例,即距离预测目标更近的近邻具有更高的权重
    algorithm ‘brute’对应第一种蛮力实现,‘kd_tree’对应第二种KD树实现,‘ball_tree’对应第三种的球树实现, ‘auto’则会在上面三种算法中做权衡,选择一个拟合最好的最优算法。
    leaf_size 这个值控制了使用KD树或者球树时, 停止建子树的叶子节点数量的阈值。
    metric K近邻法和限定半径最近邻法类可以使用的距离度量较多,一般来说默认的欧式距离(即p=2的闵可夫斯基距离)就可以满足我们的需求。
    p p是使用距离度量参数 metric 附属参数,只用于闵可夫斯基距离和带权重闵可夫斯基距离中p值的选择,p=1为曼哈顿距离, p=2为欧式距离。默认为2

    我个人认为这些个参数,比较重要的应该属n_neighbors、weights了,其他默认的也都没太大问题。

    3. KNN基础版实现

    直接看代码如下,完整代码GitHub

    def fit(self, X_train, y_train):
        self.X_train = X_train
        self.y_train = y_train
    
    def predict(self, X):
        # 取出n个点
        knn_list = []
        for i in range(self.n):
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            knn_list.append((dist, self.y_train[i]))
    
        for i in range(self.n, len(self.X_train)):
            max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
            dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
            if knn_list[max_index][0] > dist:
                knn_list[max_index] = (dist, self.y_train[i])
    
        # 统计
        knn = [k[-1] for k in knn_list]
        return Counter(knn).most_common()[0][0]
    

    我的接口设计都是按照sklearn的样子设计的,fit方法其实主要用来接收参数了,没有进行任何的处理。所有的操作都在predict中,着就会导致,我们对每个点预测的时候,时间消耗比较大。这个基础版本大家看看就好,我想大家自己去写,肯定也没问题的。

    4. KdTree版本实现

    kd树算法包括三步,第一步是建树,第二部是搜索最近邻,最后一步是预测。

    4.1 构建kd树

    kd树是一种对n维空间的实例点进行存储,以便对其进行快速检索的树形结构。kd树是二叉树,构造kd树相当于不断的用垂直于坐标轴的超平面将n维空间进行划分,构成一系列的n维超矩阵区域。

    下面的流程图更加清晰的描述了kd树的构建过程:

    image

    kdtree树的生成代码:

    # 建立kdtree
    def create(self, dataSet, label, depth=0):
        if len(dataSet) > 0:
            m, n = np.shape(dataSet)
            self.n = n
            axis = depth % self.n
            mid = int(m / 2)
            dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
            node = Node(dataSetcopy[mid], label[mid], depth)
            if depth == 0:
                self.KdTree = node
            node.lchild = self.create(dataSetcopy[:mid], label, depth+1)
            node.rchild = self.create(dataSetcopy[mid+1:], label, depth+1)
            return node
        return None
    

    4.2 kd树搜索最近邻和预测

    当我们生成kd树以后,就可以去预测测试集里面的样本目标点了。预测的过程如下:

    1. 对于一个目标点,我们首先在kd树里面找到包含目标点的叶子节点。以目标点为圆心,以目标点到叶子节点样本实例的距离为半径,得到一个超球体,最近邻的点一定在这个超球体内部。
    2. 然后返回叶子节点的父节点,检查另一个子节点包含的超矩形体是否和超球体相交,如果相交就到这个子节点寻找是否有更加近的近邻,有的话就更新最近邻,并且更新超球体。如果不相交那就简单了,我们直接返回父节点的父节点,在另一个子树继续搜索最近邻。
    3. 当回溯到根节点时,算法结束,此时保存的最近邻节点就是最终的最近邻。
      kdtree树的搜索代码:
    # 搜索kdtree的前count个近的点
    def search(self, x, count = 1):
        nearest = []
        for i in range(count):
            nearest.append([-1, None])
        # 初始化n个点,nearest是按照距离递减的方式
        self.nearest = np.array(nearest)
    
        def recurve(node):
            if node is not None:
                # 计算当前点的维度axis
                axis = node.depth % self.n
                # 计算测试点和当前点在axis维度上的差
                daxis = x[axis] - node.data[axis]
                # 如果小于进左子树,大于进右子树
                if daxis < 0:
                    recurve(node.lchild)
                else:
                    recurve(node.rchild)
                # 计算预测点x到当前点的距离dist
                dist = np.sqrt(np.sum(np.square(x - node.data)))
                for i, d in enumerate(self.nearest):
                    # 如果有比现在最近的n个点更近的点,更新最近的点
                    if d[0] < 0 or dist < d[0]:
                        # 插入第i个位置的点
                        self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
                        # 删除最后一个多出来的点
                        self.nearest = self.nearest[:-1]
                        break
    
                # 统计距离为-1的个数n
                n = list(self.nearest[:, 0]).count(-1)
                '''
                self.nearest[-n-1, 0]是当前nearest中已经有的最近点中,距离最大的点。
                self.nearest[-n-1, 0] > abs(daxis)代表以x为圆心,self.nearest[-n-1, 0]为半径的圆与axis
                相交,说明在左右子树里面有比self.nearest[-n-1, 0]更近的点
                '''
                if self.nearest[-n-1, 0] > abs(daxis):
                    if daxis < 0:
                        recurve(node.rchild)
                    else:
                        recurve(node.lchild)
    
        recurve(self.KdTree)
    
        # nodeList是最近n个点的
        nodeList = self.nearest[:, 1]
    
        # knn是n个点的标签
        knn = [node.label for node in nodeList]
        return self.nearest[:, 1], Counter(knn).most_common()[0][0]
    

    这段代码其实比较好的实现了上面搜索的思想。如果读者对递归的过程想不太清楚,可以画下图,或者debug下我完整的代码GitHub

    5. 总结

    本文实现了KNN的基础版和KdTree版本,读者可以去尝试下ballTree的版本,理论上效率比KdTree还要好一些。

  • 相关阅读:
    JDBC存取二进制文件示例
    java多线程向数据库中加载数据
    Lucene建索引代码
    postgresql存储二进制大数据文件
    java项目使用Echarts 做柱状堆叠图,包含点击事件
    子页面获取父页面控件
    JSTL和select标签的组合使用
    log4j配置祥解
    IT项目经理应具备的十大软技能
    Spring和Struct整合的三个方法
  • 原文地址:https://www.cnblogs.com/huangyc/p/10294307.html
Copyright © 2011-2022 走看看