zoukankan      html  css  js  c++  java
  • KNN算法

    一、KNN算法简述

      K近邻算法(kNN,k-NearestNeighbor)分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

      K近邻算法的2个关键要素是:已标注样本量及其可靠性、距离。

      样本要求:

      对于给定的已标注的样本,理想状态下,我们认为:样本各类别在给定维度上可分。当一个样本,在其最近的k个样本中充斥了过多其它类别的样本时,会导致k近邻算法的准确率大大降低。这就要求各类别之间的样本量大小不宜过度失衡,且其分布应具有明显的可分性。

      距离要求:

      计算距离的方式,包括欧式距离、马氏距离。。。,必须要统一各个维度上的量纲。

      算法流程

      给定已知带分类标签的样本,以及待分类的未知样本
          1.计算已知类别数据集中的点与当前点之间的距离
          2.按照距离递增次序排序
          3.选取与当前点距离最小的k个点
          4.确定前k个点所在类别的出现频率
          5.返回前k个点出现频率最高的类别作为当前点的预测分类

    二、KNN算法实现

      1.python3实现KNN算法

      LoadDataSet类用于从指定网站上下载数据集(如果有变则不可用);Normalizer类用于归一化处理数据集;Plot类用于绘制三维散点图;KNN类用于实现k-近邻算法。

    import urllib, bs4
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D

    class
    LoadDataSet(object): def __init__(self): self.init_dataSet = [] def load_dataSet(self): """爬取海伦约会的部分已分类数据集""" file=urllib.request.urlopen('http://www.codeforge.com/read/252622/datingTestSet2.txt__html') html=file.read() # utf-8编码的字节流 btf = bs4.BeautifulSoup(html, "html.parser") # 解析器 node = btf.find("pre") # 找到节点 string = node.get_text() # 获取节点内的文本 rows = string.split(" ") dataSet = [] labels = [] for i, row in enumerate(rows): split = row.split(" ") if len(split) == 4 and len(split[0]) > 1: data = [int(split[0]), float(split[1]), float(split[2]), int(split[3])] dataSet.append(data[: 3]) labels.append(data[3]) self.init_dataSet.append(data) return dataSet, labels class Normalizer(object): def maxmin(self, dataSet): """最大值最小值归一法""" dataSet = np.array(dataSet) shape = dataSet.shape norm = np.zeros(shape) self.max_value = dataSet.max(0) # 求每一列的最小值,并组成一行向量 self.min_value = dataSet.min(0) # 求每一列的最大值,并组成一行向量 self.ranges = self.max_value - self.min_value # 同样是行 norm = dataSet - np.tile(self.min_value, (shape[0], 1)) # 将各列最小值组成的行展开成和dataSet同shape的数据集 return norm / np.tile(self.ranges, (shape[0], 1)) # 矩阵元素相除 class Plot(object): def plot(self, init_dataSet): init = np.array(init_dataSet) fig = plt.figure(figsize=(12, 8)) ax = fig.add_subplot(111, projection="3d") for i in range(1, 4): # 1,2,3分三组数据,绘到一张图例 arr = init[init[:, 3] == i] ax.scatter(xs=arr[:, 0], ys=arr[:, 1], zs=arr[:, 2], zdir="z", s=40) plt.show() class KNN(LoadDataSet, Normalizer, Plot): def __init__(self): super().__init__() def classify(self, row, k): """ row: 要分类的行向量 k: 最近邻点数 """ dataSet, labels = self.load_dataSet() # 从网上下载数据集 dataSet = self.maxmin(dataSet) # 归一化 row = (np.array(row) - self.min_value) / self.ranges # 归一化 k_list = [] # k_list用来保证k长度的列表 # 取k个距离最小值对应的标签及对应的距离放到k_list中,统计各自出现的次数 for i, _ in enumerate(dataSet): distance = round(np.linalg.norm(dataSet[i, :] - row), 4) # 计算欧式距离 new_tuple = (labels[i], distance) if len(k_list) < k: # 如果k个点没选完,就直接填充并排序 k_list.append(new_tuple) k_list = sorted(k_list, key=lambda x:x[1]) else: # 如果k个点已选完,那么找到第一次大于该距离的点并插入 k_list = self._replace(new_tuple, k_list) # 计算分类 return self._k_count(k_list) def _replace(self, new_tuple, k_list): """更新k_list""" k = len(k_list) for j, old_tuple in enumerate(k_list): if new_tuple[1] < old_tuple[1]: k_list.insert(j, new_tuple) # 在这个位置插入当前的new_tuple k_list = k_list[: k] # 仍然选取前k个点 break return k_list def _k_count(self, k_list): label_list = [tup[0] for _, tup in enumerate(k_list)] # unique_label = list(set(label_list)) label_count = [(label, label_list.count(label)) for label in unique_label] self.label_count = sorted(label_count, key=lambda x:x[1], reverse=True) return self.label_count[0][0] def test(self): init_dataSet = np.array(self.init_dataSet) np.random.shuffle(init_dataSet) test_records = int(0.1 * init_dataSet.shape[0]) test_data, test_labels = init_dataSet[: test_records, :3], init_dataSet[: test_records, 3] k = 5 test_outcome = [] for i, row in enumerate(test_data): label = self.classify(row, k) test_outcome.append(label == test_labels[i]) print("The real label is %d, the test label is %d." % (test_labels[i], label)) accuracy = sum(test_outcome) / test_data.shape[0] print("The KNN accuracy is %d." % accuracy)

      classify为k-近邻算法的核心代码,比较简单的一个逻辑。给定k值,就找距离最近的k个样本,计算里面各类别的数量,并认定数量最多的类别是测试row的类别。

    knn = KNN()
    knn.load_dataSet()
    knn.plot(knn.init_dataSet)
    knn.test()
    # knn.classify([26052, 1.441871, 0.805124], 10)  # 测试单个数据分类

      2.sikit-learn实现

    from sklearn.neighbors import KNeighborsClassifier

    loader = LoadDataSet() # 上面代码的LoadDataSet() dataSet, labels = loader.load_dataSet() # 上面代码的load_dataSet() neigh = KNeighborsClassifier(n_neighbors=10) neigh.fit(dataSet, labels) print(neigh.predict([dataSet[0]])) print(neigh.predict_proba([dataSet[0]])) print(dataSet[0], labels[0])

      sikit-learn在介绍k近邻时又提供了四个跟其有关的算法。RadiusNeighborsClassifier、 KNeighborsRegressor 、 RadiusNeighborsRegressor 、 NearestNeighbors。

      第1个是以指定的半径来确定邻近样本的个数,并计算各类别的数量。第4个是最近邻,极端情况,相当于k=1。其余两个用于回归,不再赘述。

  • 相关阅读:
    转载:稳定性,鲁棒性和非脆弱性的精辟解读
    BZOJ 2806: [Ctsc2012]Cheat(单调队列优化dp+后缀自动机)
    CF 235C. Cyclical Quest(后缀自动机)
    BZOJ 5137: [Usaco2017 Dec]Standing Out from the Herd(后缀自动机)
    2019/2/28 考试记录
    后缀自动机的应用
    CF 452E. Three strings(后缀数组+并查集)
    BZOJ 2281: [Sdoi2011]黑白棋(dp+博弈论)
    CF 39E. What Has Dirichlet Got to Do with That?(记忆化搜索+博弈论)
    LUOGU P4783 【模板】矩阵求逆(高斯消元)
  • 原文地址:https://www.cnblogs.com/kuaizifeng/p/9105950.html
Copyright © 2011-2022 走看看