zoukankan      html  css  js  c++  java
  • 我眼中的K-近邻算法

    有一句话这样说:如果你想了解一个人,你可以从他身边的朋友开始。

    如果与他交往的好友都是一些品行高尚的人,那么可以认为这个人的品行也差不了。

    其实古人在这方面的名言警句,寓言故事有很多。例如:人以类聚,物以群分。近朱者赤近墨者黑

    其实K-近邻算法和古人的智慧想通,世间万物息息相通,你中有我,我中有你。

    K-近邻原理:

    存在一个训练集,我们知道每一个样本的标签,例如训练样本是一群人,他们都有相应特征,例如,爱喝酒或爱看书或逛窑子或打架斗殴或乐于助人等等,并且知道他们是好人还是坏人,然后来了一个新人(新样本),然后把新样本的特征与样本集中数据对应的特征进行比较,然后算法提取集中特征最相似数据的分类标签,就是比较这个新人具有的品行与那一群人中谁的品行相近,选取出样本集中数据中前K个数据(这就是K的来历),然后查看这K个数据的标签,选取出现最多类作为新样本的分类。就是查看选出的这些人,看看是好人多还是坏人多,如果好人多,那么我们就确定这个新人是好人。

    K-近邻算法没有训练过程,它直接对新样本进行分类。

     代码来源机器学习实战,python3.7可用,详细注释:

    #coding=utf-8
    from numpy import *
    import operator
    import os,sys
    
    def createDataSet():
        #数组转换成矩阵
        group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
        labels = ['A','A','B','B']
        return group,labels
    
    #inx为测试样本
    def classify0(inx,dataSet,labels,k):
        #shape[0]给出行数,shape[1]列数
        dataSetSize = dataSet.shape[0]
        #把inx矩阵的每一行复制dataSetSize次,列不复制
        #为了把该样本与训练集中每一个样本计算出距离
        #计算欧氏距离
        diffMat = tile(inx,(dataSetSize,1)) - dataSet
        #距离的平方差
        sqDiffMat = diffMat**2
        #把数组每一行求和
        sqDistances = sqDiffMat.sum(axis=1)
        distances = sqDistances**0.5
        #argsort 从小到大排序,但是返回的是下标
        sortedDistIndices = distances.argsort()
        classCount = {}
        #k是前k个最小距离
        for i in range(k):
            #把最小距离对应的标签赋值给voteIlabel
            voteIlabel = labels[sortedDistIndices[i]]
            #投票算法,统计前k个数据的标签类型及其出现的个数
            classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
        #排序选出出现次数最多的标签,(注意:Python 3 renamed dict.iteritems() -> dict.items())
        sortedClassCount = sorted(classCount.items(),
         key=operator.itemgetter(1),reverse=True)
        return sortedClassCount[0][0]
    
    def file2matrix(filename):
        fr = open(filename)
        #文件有多少行
        arrayOLines = fr.readlines()
        numberOfLines = len(arrayOLines)
        #返回一个(numberOfLines,3)的零矩阵
        returnMat = zeros((numberOfLines,3))
        classLabelVector = []
        index = 0
        for line in arrayOLines:
            #去除字符串的首尾的字符(空格,回车)
            line = line.strip()
            listFromLine = line.split('	')
            #复制行给returnMat
            returnMat[index,:] = listFromLine[0:3]
            #获取标签,这里需要把字符串类型转换成int类型
            if listFromLine[-1] == 'largeDoses':
                classLabelVector.append(3)
            elif listFromLine[-1] == 'smallDoses':
                classLabelVector.append(2)
            elif listFromLine[-1] == 'didntLike':
                classLabelVector.append(1)
            else:
                classLabelVector.append(int(listFromLine[-1]))
            index += 1
        return returnMat,classLabelVector
    
    #分析数据
    '''
    控制台输入
    import matplotlib
    import matplotlib.pyplot as plt
    #定义一个图像窗口
    fig = plt.figure()
    #意思是窗口背划分成1*1个格子,使用第一个格子
    ax = fig.add_subplot(111)
    #描绘散点图
    ax.scatter(datingDataMat[:,1],datingDataMat[:,2])
    #使用颜色来分辨
    ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))
    plt.show()
    
    '''
    #给出的数据集往往会遇见这样的问题,就是每一个特征值的取值不在
    #同一个数量级,有的取值会很大,这样会严重影响结果的准确性
    #所以要归一化特征值到0~1之间
    #公式:newValue = (oldValue-min)/(max-min)
    def autoNorm(dataSet):
        #返回每一列最小值(1,m)
        minVals = dataSet.min(0)
        #返回每一列最大值
        maxVals = dataSet.max(0)
        ranges = maxVals - minVals
        normDataSet = zeros(shape(dataSet))
        m = dataSet.shape[0]
        normDataSet = dataSet - tile(minVals,(m,1))
        normDataSet = normDataSet/tile(ranges,(m,1))
        return normDataSet,ranges,minVals
    
    #分类器针对约会网站的测试代码
    def datingClassTest():
        hoRatio = 0.10
        datingDataMat,datingLabels = file2matrix('datingTestSet.txt')
        #归一化
        normMat,ranges,minVals = autoNorm(datingDataMat)
        m = normMat.shape[0]
        #选取数据集的10%作为测试集
        numTestVecs = int(m*hoRatio)
        errorCount = 0.0
        #循环对测试集进行分类,然后计算准确率
        for i in range(numTestVecs):
            classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
            print ("the classifier came back with:%d,the real answer is:%d"%(classifierResult,datingLabels[i]))
            if (classifierResult != datingLabels[i]):errorCount += 1.0
        print ("the total error rate is:%f"%(errorCount/float(numTestVecs)))
    
        
    def classifyPerson():
        resultList = ['not at all', 'in small doses', 'in large doses']
        #python3 输入是input
        percentTats = float(input("percentage of time spent playing video games?"))
        ffMiles = float(input("frequent flier miles earned per year?"))
        iceCream = float(input("liters of ice cream consumed per year?"))
        datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
        normMat, ranges, minVals = autoNorm(datingDataMat)
        inArr = array([ffMiles, percentTats, iceCream])
        classifierResult = classify0((inArr - minVals)/ranges, normMat, datingLabels, 3)
        print ("You will probably like this person: %s" % resultList[classifierResult - 1])
        
        
        #识别手写数字
        #把32*32的矩阵转换成1*1024矩阵
    def img2vector(filename):
        returnVect = zeros((1,1024))
        fr = open(filename)
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0,32*i+j] = int(lineStr[j])
        return returnVect
        
    def handwritingClassTest():
        hwLabels= []
        #获取目录的内容
        trainingFileList = os.listdir('trainingDigits')
        m = len(trainingFileList)
        trainingMat = zeros((m,1024))
        for i in range(m):
            #从文本文件的名称中截取是什么数字
            fileNameStr = trainingFileList[i]
            fileStr = fileNameStr.split('.')[0]
            classNumStr = int(fileStr.split('_')[0])
            hwLabels.append(classNumStr)
            trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
        testFileList = os.listdir('testDigits')
        errorCount = 0.0
        mTest = len(testFileList)
        for i in range(mTest):
            fileNameStr = testFileList[i]
            fileStr = fileNameStr.split('.')[0]
            classNumStr = int(fileStr.split('_')[0])
            vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
            classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)
            #计算精确性
            print ("the classifier came back with:%d,the real answer is:%d" % (classifierResult,classNumStr))
            if (classifierResult != classNumStr):errorCount += 1.0
        print ("
     the total number of errors is:%d" % errorCount)
        print ("
     the total error rate is:%f" % (errorCount/float(mTest)))

    算法主要有两个主要的步骤:

    (1)求解两向量之间的距离来比较相似性:

      

     (2) 排序选出前K个相似点,筛选出出现频率最高的类别

       代码中直接调用排序算法,如果对于大量数据,排序会很耗费时间,所以可以优化排序算法:Kd树

            筛选评论最高的是通过投票的方式。

     上面代码中包括了识别手写体的代码,依然用的是欧氏距离,之前做过一个使用神经网络训练做的手写体数字识别,我想比较这两个算法的准确性。

    kNN算法没有训练过程,算法也十分简单,但是在实践的过程中我发现,KNN具有局限性。我的做法是

    kNN识别手写体:

    先把数字的灰度图转换成32*32的字符文件的格式,然后使用kNN算法,发现不同的测试集的准确性相差很大,如果使用和训练集相近的测试集去测试,所谓相近就是说数字的大小,粗细都会影响识别的准确性,所以我用不同的测试集得到的结果完全不同,如果用训练集去作为测试集使用,准确率会达到99%,但是换一个不同的测试集,准确率就会降到34%左右(比蒙的好一点点)。如果要提高准确性,必须加大

    训练集(尽量包含所有的手写体类型),再调整K的取值,如果那样的话,做一次分类,就要对大量的数据集进行比对,排序选出相近的,这样效率非常低。

    神经网络识别手写体:

    在训练的过程中会消耗时间,但是一旦模型训练完毕,准确率会很高。

    所以说kNN算法适合数据集较小的情况的分类。

    注意:K-近邻是监督学习,K-Means是无监督学习

  • 相关阅读:
    CF1474C Array Destruction 题解 贪心
    洛谷P1854 花店橱窗布置 题解 2D/0D型动态规划
    POJ1704 Georgia and Bob 题解 阶梯博弈
    HDU1848 Fibonacci again and again 题解 SG函数
    SG函数简要学习笔记
    洛谷P2868 [USACO07DEC]Sightseeing Cows G 题解 01分数规划+SPFA判负环
    洛谷P4322 [JSOI2016]最佳团体 题解 01分数规划+树上背包
    从零开发SIP客户端(Windows)踩坑实录
    richedit禁用输入法的实现
    VS2013无法加载解决方案中的项目(转)
  • 原文地址:https://www.cnblogs.com/zhxuxu/p/9640656.html
Copyright © 2011-2022 走看看