zoukankan      html  css  js  c++  java
  • KNN算法原理(python代码实现)

    kNN(k-nearest neighbor algorithm)算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。简单地说,K-近邻算法采用测量不同特征值之间的距离方法进行分类。 
    - 优点:精度高、对异常值不敏感、无数据输入假定。 
    - 缺点:计算复杂度高、空间复杂度高。 
    - 适用数据范围:数值型和标称型。

    举个简单的例子,一群男生和一群女生,我们知道他们的身高和性别。 
    如下表格:

    身高性别
    165
    166
    169
    175
    178
    180
    190
    179 未知

    这里有8位同学,已知所有人的身高,其中7个人的性别,还有位同学只知道身高,那么这位同学可能是男生还是女生。

    下面我们利用KNN算法来分类: 
    1. 求这位同学跟其他同学的身高差; 
    2. 设定一个K值,选择跟这位同学身高相差最小的K个同学; 
    3. 在这K个同学中,哪种性别的人多,就认为这位同学属于哪种性别。

    举例:设K为5

    计算这位同学与其他同学的身高差的绝对值:如下

    身高差的绝对值性别
    14
    13
    10
    4
    1
    1
    11

    与这位同学身高差最接近的5位同学中,有4位男生,1位女生,所以我们认为这位同学也是男生。

    现实当中,肯定不止考虑身高,还要考虑体重、头发长度等等因素。算法的思想还是一样,这时候“距离”的计算公式为欧氏距离:

    在多个特征,多个分类的情况下,KNN算法思想: 
    1. 计算预分类的与样本中的欧氏距离(当然还有其他距离); 
    2. 选择距离最小的K的样本; 
    3. 把预分类归为:K个样本中,类别最多的那个类别。

    下面是KNN的python代码实现:

     1 from numpy import *
     2 import operator
     3 from os import listdir
     4 filename_train='C:\Users\Administrator\Desktop\digits\trainingDigits\'
     5 filename_test='C:\Users\Administrator\Desktop\digits\testDigits\'
     6 #KNN算法函数
     7 def classify0(inX, dataSet, labels, k):
     8     dataSetSize = dataSet.shape[0]
     9     diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    10     sqDiffMat = diffMat ** 2
    11     sqDistances = sqDiffMat.sum(axis=1)
    12     distances = sqDistances ** 0.5
    13     sortedDistIndicies = distances.argsort()
    14     classCount = {}
    15     for i in range(k):
    16         voteIlabel = labels[sortedDistIndicies[i]]
    17         classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    18     sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    19     return(sortedClassCount[0][0])
    20 
    21 #将图像转化为向量
    22 def img2vector(filename):
    23     returnVect = zeros((1, 1024))
    24     fr = open(filename)
    25     for i in range(32):
    26         lineStr = fr.readline()
    27         for j in range(32):
    28             returnVect[0, 32 * i + j] = int(lineStr[j])
    29     return(returnVect)
    30 
    31 #手写分类
    32 def handwritingClassTest():
    33     hwLabels = []
    34     trainingFileList = listdir(filename_train)  # 载入数据所在目录
    35     m = len(trainingFileList)
    36     trainingMat = zeros((m, 1024))
    37     for i in range(m):
    38         fileNameStr = trainingFileList[i]
    39         fileStr = fileNameStr.split('.')[0]
    40         classNumStr = int(fileStr.split('_')[0])
    41         hwLabels.append(classNumStr)
    42         trainingMat[i, :] = img2vector(filename_train+fileNameStr)
    43     testFileList = listdir(filename_test)
    44     errorCount = 0.0
    45     mTest = len(testFileList)
    46     for i in range(mTest):
    47         fileNameStr = testFileList[i]
    48         fileStr = fileNameStr.split('.')[0]
    49         classNumStr = int(fileStr.split('_')[0])
    50         vectorUnderTest = img2vector(filename_test+fileNameStr)
    51         classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
    52         print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
    53         if (classifierResult != classNumStr): errorCount += 1.0
    54     print("
    the total number of errors is: %d" % errorCount)
    55     print("
    the total error rate is: %f" % (errorCount / float(mTest)))
    56 
    57 if __name__=='__main__':
    58     handwritingClassTest()

    结果如下,效果还不错。 
    the total number of errors is: 11 
    the total error rate is: 0.011628

    参考: 
    1. 《Machine Learning in Action》 
    2. scikit-learn KNeighborsClassifier 官网

  • 相关阅读:
    Netty源码分析之ByteBuf引用计数
    GitHub git push大文件失败(write error: Broken pipe)完美解决
    Windows10 Docker安装详细教程
    全面的Docker快速入门教程
    十本你不容错过的Docker入门到精通书籍推荐
    CentOS 8.4安装Docker
    postgres之一条sql查询总数及部分数据
    neo4j相关操作
    git上传大文件
    分布式文件系统fastdfs安装以及python调用
  • 原文地址:https://www.cnblogs.com/luozeng/p/8605033.html
Copyright © 2011-2022 走看看