zoukankan      html  css  js  c++  java
  • K临近算法

    邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。

    基本思想

    分类思想比较简单,从训练样本中找出K个与其最相近的样本,然后看这k个样本中哪个类别的样本多,则待判定的值(或说抽样)就属于这个类别。

    kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。

    算法流程


    1. 准备数据,对数据进行预处理
    2. 选用合适的数据结构存储训练数据和测试元组
    3. 设定参数,如k
    4.维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组。随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测    试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
    5. 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L 与优先级队列中的最大距离Lmax
    6. 进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
    7. 遍历完毕,计算优先级队列中k 个元组的多数类,并将其作为测试元组的类别。
    8. 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k 值。[1] 

    优点


    1.简单,易于理解,易于实现,无需估计参数,无需训练;
    2. 适合对稀有事件进行分类;
    3.特别适合于多分类问题(multi-modal,对象具有多个类别标签), kNN比SVM的表现要好。

    缺点

    (1)当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本             占多数。 该算法只计算“最近的”邻居样本,某一类的样本数量很大,那么或者这类样本并不接近目标样本,或者这类样本很靠近目标样本。无论           怎样,数量并不能影响运行结果。

    (2)计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。

    (3)可理解性差,无法给出像决策树那样的规则。

    (4)K值需要预先设定,而不能自适应。


    knn算法用于手写数字识别

    #-*- coding: utf-8 -*-
    from numpy import *
    import operator
    import time
    from os import listdir
    
    
    
    def classify(inputPoint,dataSet,labels,k):
      dataSetSize = dataSet.shape[0]	 #已知分类的数据集(训练集)的行数
      #先tile函数将输入点拓展成与训练集相同维数的矩阵,再计算欧氏距离
      diffMat = tile(inputPoint,(dataSetSize,1))-dataSet  #样本与训练集的差值矩阵
      sqDiffMat = diffMat ** 2					#差值矩阵平方
      sqDistances = sqDiffMat.sum(axis=1)		 #计算每一行上元素的和
      distances = sqDistances ** 0.5			  #开方得到欧拉距离矩阵
      sortedDistIndicies = distances.argsort()	#按distances中元素进行升序排序后得到的对应下标的列表
      #选择距离最小的k个点
      classCount = {}
      for i in range(k):
        voteIlabel = labels[ sortedDistIndicies[i] ]
        classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
      #按classCount字典的第2个元素(即类别出现的次数)从大到小排序
      sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
      return sortedClassCount[0][0]
    
    
    
    #文本向量化 32x32 -> 1x1024
    def img2vector(filename):
      returnVect = []
      fr = open(filename)
      for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
          returnVect.append(int(lineStr[j]))
      return returnVect
    
    
    #构建训练集数据向量,及对应分类标签向量
    def trainingDataSet():
      hwLabels = []
      trainingFileList = listdir('trainingDigits')		   #获取目录内容
      m = len(trainingFileList)
      trainingMat = zeros((m,1024))						  #m维向量的训练集
      for i in range(m):
        fileNameStr = trainingFileList[i]
        hwLabels.append(classnumCut(fileNameStr))
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
      return hwLabels,trainingMat
    
    
    #从文件名中解析分类数字
    def classnumCut(fileName): 
      fileStr = fileName.split('.')[0]  
      classNumStr = int(fileStr.split('_')[0]) 
      return classNumStr
    
    
    #测试函数
    def handwritingTest():
      hwLabels,trainingMat = trainingDataSet()	#构建训练集
      testFileList = listdir('testDigits')		#获取测试集
      errorCount = 0.0							#错误数
      mTest = len(testFileList)				   #测试集总样本数
      t1 = time.time()
      for i in range(mTest):
        fileNameStr = testFileList[i]
        classNumStr = classnumCut(fileNameStr)
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        #调用knn算法进行测试
        classifierResult = classify(vectorUnderTest, trainingMat, hwLabels, 3)
        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
        if (classifierResult != classNumStr): errorCount += 1.0
      print "
    the total number of tests is: %d" % mTest			   #输出测试总样本数
      print "the total number of errors is: %d" % errorCount		   #输出测试错误样本数
      print "the total error rate is: %f" % (errorCount/float(mTest))  #输出错误率
      t2 = time.time()
      print "Cost time: %.2fmin, %.4fs."%((t2-t1)//60,(t2-t1)%60)	  #测试耗时
    
    if __name__ == "__main__":
      handwritingTest()
    

    training及test数据集下载

    版权声明:

  • 相关阅读:
    java后台对上传的图片进行压缩
    Reflections框架,类扫描工具
    commons-httpclient和org.apache.httpcomponents的区别
    sql里面插入语句insert后面的values关键字可省略
    Callable接口、Runable接口、Future接口
    Java多线程之Callable接口的实现
    说说Runnable与Callable
    论坛贴吧问题:如何终止运行时间超时的线程
    使用Future停止超时任务
    spring的@Transactional注解详细用法
  • 原文地址:https://www.cnblogs.com/walccott/p/4957102.html
Copyright © 2011-2022 走看看