zoukankan      html  css  js  c++  java
  • KNN算法的简单实现

    一  算法原理:已知一个训练样本集,其中每个训练样本都有自己的标记(label),即我们知道样本集中每一个样本数据与所属分类的对应关系。输入没有标记的新数据后,将新数据的每个特征与样本集中的数据对应的特征进行比较,然后提取样本集中特征最相似数据的分类标记。一般的,我们选择样本集中前k个最相似的数据分类标签,其中出现次数最多的分类作为我们新数据的分类标记。简单的说,k_近邻算法采用测量不同特征值之间的距离方法进行分类。

    算法优点: 精度高、对异常值不敏感,无数据输入假设。

    算法缺点: 由于要将每个待分类的数据特征与样本集中的每个样例进行对应特征距离的计算,所以计算的时间空间复杂度高。

    二  算法的实现(手写体识别)

    1.数据准备:采用的是32*32像素的黑白图像(0-9,每个数字大约200个样本,trainingDigits用于数据分类器训练,testDigits用于测试),这里为了方便理解,将图片转换成了文本格式。

    2.代码实现:

                  将图片转化为一个向量,我们把一个32*32的二进制图像矩阵转化为一个1*1024的向量,编写一个函数vector2d,如下代码

      1 def vector2d(filename):
      2     rows = 32
      3     cols = 32
      4     imgVector = zeros((1,rows * cols))
      5     fileIn = open(filename)
      6     for row in xrange(rows):
      7         lineStr = fileIn.readline()
      8         for col in xrange(cols):
      9             imgVector[0,row *32 + col] = int(lineStr[col])
     10     return imgVector
     11 
    View Code

                  trainingData set 和testData set 的载入

      1 '''load dataSet '''
      2 def loadDataSet():
      3     print '....Getting training data'
      4     dataSetDir =  'D:/pythonCode/MLCode/KNN/'
      5     trainingFileList = os.listdir(dataSetDir + 'trainingDigits')
      6     numSamples = len(trainingFileList)
      7 
      8     train_x = zeros((numSamples,1024))
      9     train_y = []
     10     for i  in xrange(numSamples):
     11         filename = trainingFileList[i]
     12         train_x[i,:] = vector2d(dataSetDir + 'trainingDigits/%s'%filename)
     13         label = int(filename.split('_')[0])
     14         train_y.append(label)
     15     ''' ....Getting testing data...'''
     16     print '....Getting testing data...'
     17     testFileList = os .listdir(dataSetDir + 'testDigits')
     18     numSamples = len(testFileList)
     19     test_x = zeros((numSamples,1024))
     20     test_y = []
     21     for i in xrange(numSamples):
     22         filename = testFileList[i]
     23         test_x[i,:] = vector2d(dataSetDir + 'testDigits/%s'%filename)
     24         label = int(filename.split('_')[0])
     25         test_y.append(label)
     26 
     27     return train_x,train_y,test_x,test_y
    View Code

                    分类器的构造

      1 from numpy import *
      2 
      3 import os
      4 
      5 def kNNClassify(newInput,dataSet,labels,k):
      6     numSamples = dataSet.shape[0]
      7 
      8     diff = tile(newInput,(numSamples,1)) - dataSet
      9     squaredDiff = diff ** 2
     10     squaredDist = sum(squaredDiff,axis = 1)
     11     distance = squaredDist ** 0.5
     12 
     13     sortedDistIndex = argsort(distance)
     14 
     15     classCount =  {}
     16     for i in xrange(k):
     17         votedLabel = labels[sortedDistIndex[i]]
     18         classCount[votedLabel] = classCount.get(votedLabel,0) + 1
     19 
     20     maxValue = 0
     21     for key,value in classCount.items():
     22         if maxValue < value:
     23             maxValue = value
     24             maxIndex = key
    View Code

    分类测试

      1 def testHandWritingClass():
      2     print 'load data....'
      3     train_x,train_y,test_x,test_y = loadDataSet()
      4     print'training....'
      5 
      6     print'testing'
      7     numTestSamples = test_x.shape[0]
      8     matchCount = 0.0
      9     for i in xrange(numTestSamples):
     10         predict = kNNClassify(test_x[i],train_x,train_y,3)
     11         if predict != test_y[i]:
     12 
     13             print 'the predict is ',predict,'the target value is',test_y[i]
     14 
     15         if predict == test_y[i]:
     16             matchCount += 1
     17     accuracy = float(matchCount)/numTestSamples
     18 
     19     print'The accuracy is :%.2f%%'%(accuracy * 100)
    View Code

               测试结果 

      1 testHandWritingClass()
      2 load data....
      3 ....Getting training data
      4 ....Getting testing data...
      5 training....
      6 testing
      7 the predict is  7 the target value is 1
      8 the predict is  9 the target value is 3
      9 the predict is  9 the target value is 3
     10 the predict is  3 the target value is 5
     11 the predict is  6 the target value is 5
     12 the predict is  6 the target value is 8
     13 the predict is  3 the target value is 8
     14 the predict is  1 the target value is 8
     15 the predict is  1 the target value is 8
     16 the predict is  1 the target value is 9
     17 the predict is  7 the target value is 9
     18 The accuracy is :98.84%
    View Code

    注:以上代码运行环境为Python2.7.11

    从上面结果可以看出knn 分类效果还不错,在我看来,knn就是简单粗暴,就是把未知分类的数据特征与我们分类好的数据特征进行比对,选择最相似的标记作为自己的分类,辣么问题来了,如果我们的新数据的特征在样本集中比较少见,这时候就会出现问题,分类错误的可能性非常大,反之,如果样例集中某一类的样例比较多,那么新数据被分成该类的可能性就会大,如何保证分类的公平性,我们就需要进行加权了。

    补充:关于K值的选取,当k越小时,分类结果对原数据的敏感性越强,易受到异常数据的影响,即模型越复杂。

    数据来源:http://download.csdn.net/download/qq_17046229/7625323

  • 相关阅读:
    PAT1001
    关于yahoo.com.cn邮箱导入Gmail邮箱验证异常的机制解析及解决办法
    浙大机试感受
    PAT1002
    mysql修改密码后无法登陆问题
    Windows 不能在 本地计算机 启动 OracleDBConsoleorcl
    Deprecated: Function ereg_replace() is deprecated
    PHP中静态方法(static)与非静态方法的使用及区别
    微信小程序开发,weui报“渲染层错误”的解决办法
    Android系统下载管理DownloadManager功能介绍及使用示例
  • 原文地址:https://www.cnblogs.com/lpworkstudyspace1992/p/5470621.html
Copyright © 2011-2022 走看看