zoukankan      html  css  js  c++  java
  • KNN原理和实现

    1. K近邻算法原理

    a. k近邻算法是一种基本的分类与回归方法

      分类问题: 对于新的样本,根据其k个最近邻的训练样本的标签,通过多数表决的方式进行预测

      回归问题: 对于新的样本,根据其k个最近邻的训练样本标签值的均值作为预测值

    b. k近邻算法不具有显示的学习过程,属于直接预测,是惰性学习的代表

    c. k近邻算法是一个非参数学习算法,没有任何的参数(k属于超参数,不是学习的参数)

      k近邻算法具有很高的容量,训练的样本数量比较大时能够获得较高的精度

      缺点:计算成本高,需要构建一个N*N的距离举证,计算量为O(N*N),其中N为训练样本的数量

            当数据集是几十个亿样本时,计算量不能够接收

         数据集很小时,泛化能力很差,容易过拟合

           无法判断特征的重要性

    d. k近邻算法的三个重要因素

      k值的选择

      距离度量

      决策规则

    2. k值的选择

    a. 当k值为1时,为最近邻算法,将训练集中与输入最近的点的类别作为输入点的分类

    b. 当k的值较小,相当于用较小领域中的训练样本进行预测,学习的偏差较小

      当近邻的点恰好为噪声,预测会出错,K值减小意味这模型整体变复杂,容易出现过拟合

      优点: 减小学习的偏差

      缺点: 增大学习的方差(波动较大)

    c. 当k的值较大,相当于用较大的领域中的训练样本进行预测

      输入样本较远的训练样本也会对预测起作用,使预测偏离预期的结果

      优点: 减小学习的方差(波动较小)

      缺点: 增大学习的偏差

    d. 应用中,k一般选区一个较小的数值,通过交叉验证法来选取最优的k值

    3. 距离度量

    a. 特征空间中两个样本之间的距离是两个样本相似程度的反应

      k近邻模型中一般是n维实数向量空间,距离一般选取欧式距离

    b. 不同的距离度量确定的最近邻点是不同的

    4. 决策规则

    a. 分类决策的规则

    分类决策通常采用多数表决,也可以基于距离的远近进行加权投票,距离越近的样本权重越大

    多数表决等价于经验风险最小化

    b. 回归决策规则

    回归决策通常采用均值回归,也可以基于距离的远近进行加权投票,距离越近的样本权重越大

    均值回归等价于经验风险最小化

    5. kd树:对训练数据进行快速的k近邻搜索

    a. 实现k近邻算法时,主要考虑的问题是如何快速的对训练数据进行k近邻搜索

    b. 最简单的方法:线性扫描(强制破解),计算输入样本与每一个训练样本之间的距离

    c. kd树是一种对k维空间的样本进行存储以便进行快速搜索的树形数据结构,它是一个二叉树,表示对k维空间的一个划分

    d. 构建kd树的过程相当于不断的用垂直于坐标轴的超平面对k维空间切分的过程,kd树的每一个节点对应于一个k维超矩形区域

    6. kd树的构建算法

    步骤1: 以x1为轴,样本集中所有样本的x1的中位数x1*为切分点,将根节点的超矩形切分成两个子区域,切分产生深度为1的左右子节点。左子节点对应x1 < x1*的区域,右子节点对应坐标x1 > x1*的子区域,落在切分超平面的点存在根节点

    步骤2:对深度为j的节点,选择xl为切分的坐标轴继续切分,l = j(mod k) + 1,切分后,树的深度为j + 1

    步骤3:直到所有的节点的两个子域中没有样本存在时,切分停止,形成kd树的区域划分

    7. kd树的搜索算法

    步骤1: 在kd树中找到包含测试点的叶节点,从根节点出发,递归访问Kd树

      若测试点当前维度坐标小于切分点的坐标,则查找当前节点的左子节点

      若测试点当前纬度的坐标大于切分点的坐标,则查找当前节点的右子节点

      在访问的过程中记录访问的各个节点的顺序,存放在先进后出的队列中,以便后面的回退

    步骤2: 以此叶节点为当前最近子节点Xnst,真实的最近点一定在测试点与当前最近子节点构成的超球体内,测试点为球心

    步骤3:从队列中弹出节点,设为Xinew(每次回退都是回退到kd树的父节点)

      若Xinew比Xnst更近,则更新Xnst

      考察节点Xinew所在的超平面与以测试点为球心,以测试点到Xnst的距离为半径的超球体是否相交

        相交:

          若测试点是Xinew的左子节点,则进入Xinew的右子节点,然后进行向下的所有并更新队列Queue,然后向上回退

          若测试点是Xinew的右子节点,则进入Xinew的左子节点,然后进行向下搜索并更新队列Queue,然后向上回退

        不相交:直接回退

    步骤4:当回退到根节点时,搜索结束,最后的当前最近点即为测试点的最近邻点

    kd树的搜索的平均计算复杂度为O(log N),N为训练集的大小

    通常最近邻搜索只需要检测最近几个叶节点即可

    8. KNN实现

    函数原型:(https://scikit-learn.org/dev/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier)

    class sklearn.neighbors.KNeighborsClassifier(n_neighbors=5, weights='uniform', algorithm='auto', leaf_size=30, p=2, metric='minkowski', metric_params=None, n_jobs=None, **kwargs)

    n_neighbors int, optional (default = 5)

    Number of neighbors to use by default for kneighbors queries.

    weights str or callable, optional (default = ‘uniform’)

    weight function used in prediction. Possible values:

    • ‘uniform’ : uniform weights. All points in each neighborhood are weighted equally.
    • ‘distance’ : weight points by the inverse of their distance. in this case, closer neighbors of a query point will have a greater influence than neighbors which are further away.
    • [callable] : a user-defined function which accepts an array of distances, and returns an array of the same shape containing the weights.

    algorithm {‘auto’, ‘ball_tree’, ‘kd_tree’, ‘brute’}, optional

    Algorithm used to compute the nearest neighbors:

    • ‘ball_tree’ will use BallTree
    • ‘kd_tree’ will use KDTree
    • ‘brute’ will use a brute-force search.
    • ‘auto’ will attempt to decide the most appropriate algorithm based on the values passed to fitmethod.

    Note: fitting on sparse input will override the setting of this parameter, using brute force.

    leaf_size int, optional (default = 30)

    Leaf size passed to BallTree or KDTree. This can affect the speed of the construction and query, as well as the memory required to store the tree. The optimal value depends on the nature of the problem.

    p integer, optional (default = 2)

    Power parameter for the Minkowski metric. When p = 1, this is equivalent to using manhattan_distance (l1), and euclidean_distance (l2) for p = 2. For arbitrary p, minkowski_distance (l_p) is used.

    metric string or callable, default ‘minkowski’

    the distance metric to use for the tree. The default metric is minkowski, and with p=2 is equivalent to the standard Euclidean metric. See the documentation of the DistanceMetric class for a list of available metrics.

    metric_params dict, optional (default = None)

    Additional keyword arguments for the metric function.

    n_jobs int or None, optional (default=None)

    The number of parallel jobs to run for neighbors search. None means 1 unless in a joblib.parallel_backend context. -1 means using all processors. See Glossary for more details. Doesn’t affect fit method.

    分类实例:

    # Importing the libraries
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    
    # Importing the dataset
    dataset = pd.read_csv('../datasets/Social_Network_Ads.csv')
    X = dataset.iloc[:, [2, 3]].values
    y = dataset.iloc[:, 4].values
    
    # Splitting the dataset into the Training set and Test set
    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.25, random_state = 0)
    
    # Feature Scaling
    from sklearn.preprocessing import StandardScaler
    sc = StandardScaler()
    X_train = sc.fit_transform(X_train)
    X_test = sc.transform(X_test)
    
    # Fitting K-NN to the Training set
    from sklearn.neighbors import KNeighborsClassifier
    classifier = KNeighborsClassifier(n_neighbors = 5, metric = 'minkowski', p = 2)
    classifier.fit(X_train, y_train)
    
    # Predicting the Test set results
    y_pred = classifier.predict(X_test)
    
    # Making the Confusion Matrix
    from sklearn.metrics import confusion_matrix
    from sklearn.metrics import classification_report
    cm = confusion_matrix(y_test, y_pred)
    print(cm)
    print(classification_report(y_test, y_pred))

     回归:https://scikit-learn.org/dev/modules/generated/sklearn.neighbors.KNeighborsRegressor.html#sklearn.neighbors.KNeighborsRegressor

    9. KDTree

    https://scikit-learn.org/dev/modules/generated/sklearn.neighbors.KDTree.html#sklearn.neighbors.KDTree

    KDTree(X, leaf_size=40, metric=’minkowski’, **kwargs)

    X array-like, shape = [n_samples, n_features]

    n_samples is the number of points in the data set, and n_features is the dimension of the parameter space. Note: if X is a C-contiguous array of doubles then data will not be copied. Otherwise, an internal copy will be made.

    leaf_size positive integer (default = 40)

    Number of points at which to switch to brute-force. Changing leaf_size will not affect the results of a query, but can significantly impact the speed of a query and the memory required to store the constructed tree. The amount of memory needed to store the tree scales as approximately n_samples / leaf_size. For a specified leaf_size, a leaf node is guaranteed to satisfy leaf_size <= n_points <= leaf_size, except in the case that n_samples leaf_size.

    metric string or DistanceMetric object

    the distance metric to use for the tree. Default=’minkowski’ with p=2 (that is, a euclidean metric). See the documentation of the DistanceMetric class for a list of available metrics. kd_tree.valid_metrics gives a list of the metrics which are valid for KDTree.

    Additional keywords are passed to the distance metric class.

    测试:

    >>> import numpy as np
    >>> rng = np.random.RandomState(0)
    >>> X = rng.random_sample((10, 3))  # 10 points in 3 dimensions
    >>> tree = KDTree(X, leaf_size=2)              
    >>> dist, ind = tree.query(X[:1], k=3)    # 1表示查询X第0行到其它行元素的距离,3表示查询3个最近的距离            
    >>> print(ind)  # indices of 3 closest neighbors
    [0 3 1]
    >>> print(dist)  # distances to 3 closest neighbors
    [ 0.          0.19662693  0.29473397]

    还可以计算点到指定半径距离内有哪些点。

  • 相关阅读:
    福建工程学院第十四届ACM校赛B题题解
    2018 ACM-ICPC青岛现场赛 B题 Kawa Exam 题解 ZOJ 4059
    联合周赛第二场 我在哪?题解
    维修数列 Splay(这可能是我写过最麻烦的题之一了。。。用平衡树维护dp。。。丧心病狂啊。。。。)
    虚树入门!世界树!
    御坂御坂题解(出自北航校赛) 约瑟夫环问题高效解决方案
    网络流24题! 开始!题解!
    AFO
    【模板库】减维的模板库【停更】
    【组合数学】Educational Codeforces Round 83 (Rated for Div. 2) D题
  • 原文地址:https://www.cnblogs.com/feng-ying/p/11155436.html
Copyright © 2011-2022 走看看