zoukankan      html  css  js  c++  java
  • Python实现KNN算法及手写程序识别

    1.Python实现KNN算法

    输入:inX:待分类的输入向量(1xN),用于与现有数据集进行比较
       dataSet:包含 m 个已知向量的训练数据集(NxM)
       labels:数据集对应的标签向量(1xM)
       k:参与投票的最近邻个数(应为奇数)
    输出:出现次数最多的类别标签(分类问题)

      1 # -*- coding: utf-8 -*-
      2 """
      3 Created on Sun Apr 16 23:01:54 2017
      4 
      5 @author: SimonsZhao
      6 """
      7
      8
      9 '''
     10 kNN: k Nearest Neighbors
     11
     12 Input:      inX: vector to compare to existing dataset (1xN)
     13             dataSet: size m data set of known vectors (NxM)
     14             labels: data set labels (1xM vector)
     15             k: number of neighbors to use for comparison (should be an odd number)
     16
     17 Output:     the most popular class label
     18
     19
     20 '''
     21 from numpy import *
     22 import operator
     23 from os import listdir
     24 
     def classify0(inX, dataSet, labels, k):
         """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

         Distances are Euclidean and measured against every row of ``dataSet``.

         Args:
             inX: vector to classify, length N (a single (1, N) row also works).
             dataSet: known samples, one per row, shape (M, N).
             labels: sequence of M class labels; ``labels[i]`` belongs to
                 ``dataSet[i]``.
             k: number of nearest neighbours that take part in the vote.

         Returns:
             The label that received the most votes.  When several labels tie,
             the winner depends on dict iteration order.

         Raises:
             IndexError: if ``k`` is smaller than 1 or larger than the number
                 of samples.
         """
         samples = asarray(dataSet, dtype=float)
         # Broadcasting subtracts inX from every row; no need to tile copies.
         diffMat = samples - asarray(inX, dtype=float)
         distances = (diffMat ** 2).sum(axis=1) ** 0.5
         sortedDistIndicies = distances.argsort()

         classCount = {}
         for i in range(k):
             voteIlabel = labels[sortedDistIndicies[i]]
             classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

         # dict.items() exists on both Python 2 and 3; iteritems() is 2-only.
         sortedClassCount = sorted(classCount.items(),
                                   key=operator.itemgetter(1),
                                   reverse=True)
         return sortedClassCount[0][0]
     38 
     def createDataSet():
         """Return a tiny four-point, two-class sample set for trying out kNN.

         Returns:
             A ``(group, labels)`` pair: ``group`` is a 4x2 array of points and
             ``labels`` is the list of class names, one per row of ``group``.
         """
         samples = [
             ([1.0, 1.1], 'A'),
             ([1.0, 1.0], 'A'),
             ([0, 0], 'B'),
             ([0, 0.1], 'B'),
         ]
         group = array([point for point, _ in samples])
         labels = [tag for _, tag in samples]
         return group, labels
     43 
     def file2matrix(filename):
         """Load a tab-separated data file into a feature matrix and label list.

         Every non-blank line must hold at least three numeric feature columns
         and an integer class label in its last column.

         Args:
             filename: path of the text file to read.

         Returns:
             A ``(returnMat, classLabelVector)`` pair: an (n, 3) float array
             with the first three columns of each line, and a list with the
             integer label of each line.

         Raises:
             ValueError: if a line has non-numeric fields or fewer than three
                 feature columns.
         """
         # Read once and close the handle promptly; blank lines (for example a
         # trailing newline at the end of the file) carry no sample.
         with open(filename) as fr:
             rows = [line.strip() for line in fr if line.strip()]

         returnMat = zeros((len(rows), 3))       # one row per sample
         classLabelVector = []
         for index, row in enumerate(rows):
             listFromLine = row.split('\t')
             returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
             classLabelVector.append(int(listFromLine[-1]))
         return returnMat, classLabelVector
     58     
     def autoNorm(dataSet):
         """Min-max rescale every feature column of ``dataSet`` to [0, 1].

         Args:
             dataSet: samples, one per row, shape (m, n).

         Returns:
             A ``(normDataSet, ranges, minVals)`` triple: the rescaled data,
             the per-column ``max - min`` and the per-column minimum.

         Raises:
             ValueError: if ``dataSet`` has no rows (min/max of an empty array).
         """
         # Work in floats so integer input is not truncated by the division.
         data = asarray(dataSet, dtype=float)
         minVals = data.min(0)
         maxVals = data.max(0)
         ranges = maxVals - minVals
         # A constant column has range 0; divide it by 1 instead so the column
         # becomes all zeros rather than NaN.  The returned ranges stay as is.
         divisor = where(ranges == 0, 1.0, ranges)
         normDataSet = (data - minVals) / divisor    # broadcast, element wise
         return normDataSet, ranges, minVals
     68    
     def datingClassTest():
         """Run a hold-out evaluation of classify0 on 'datingTestSet2.txt'.

         The first ``hoRatio`` share of the normalised rows is used as test
         vectors and the remaining rows as the training set.  Prints each
         prediction, then the overall error rate and the raw error count.
         """
         hoRatio = 0.50      # fraction of rows held out for testing (50%)
         datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
         normMat = autoNorm(datingDataMat)[0]
         m = normMat.shape[0]
         numTestVecs = int(m * hoRatio)

         # Training data is everything after the held-out prefix.
         trainMat = normMat[numTestVecs:m, :]
         trainLabels = datingLabels[numTestVecs:m]

         errorCount = 0.0
         for i in range(numTestVecs):
             classifierResult = classify0(normMat[i, :], trainMat, trainLabels, 3)
             # Single-argument print(...) behaves the same on Python 2 and 3.
             print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
             if classifierResult != datingLabels[i]:
                 errorCount += 1.0

         # Avoid dividing by zero when the file is too short to hold anything out.
         errorRate = errorCount / float(numTestVecs) if numTestVecs else 0.0
         print("the total error rate is: %f" % errorRate)
         print(errorCount)
     82     
     def img2vector(filename):
         """Flatten a 32x32 text image of digits into a 1x1024 row vector.

         Args:
             filename: path of a text file whose first 32 lines each start
                 with 32 digit characters.

         Returns:
             A (1, 1024) float array; element ``32*i + j`` holds the digit at
             row ``i``, column ``j`` of the file.

         Raises:
             IndexError: if the file has fewer than 32 lines or a line is
                 shorter than 32 characters.
             ValueError: if a character is not a digit.
         """
         returnVect = zeros((1, 1024))
         # The context manager closes the file even if parsing fails.
         with open(filename) as fr:
             for i in range(32):
                 lineStr = fr.readline()
                 for j in range(32):
                     returnVect[0, 32 * i + j] = int(lineStr[j])
         return returnVect
     91 
     def handwritingClassTest():
         """Evaluate classify0 on the handwritten-digit text images.

         Every file in 'trainingDigits' becomes one training row; every file in
         'testDigits' is classified with k=3.  Prints each prediction, then the
         number of errors and the error rate.
         """
         trainingFileList = listdir('trainingDigits')           # load the training set
         m = len(trainingFileList)
         trainingMat = zeros((m, 1024))
         hwLabels = []
         for i in range(m):
             fileNameStr = trainingFileList[i]
             hwLabels.append(_digitFromFileName(fileNameStr))
             trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)

         testFileList = listdir('testDigits')        # iterate through the test set
         mTest = len(testFileList)
         errorCount = 0.0
         for fileNameStr in testFileList:
             classNumStr = _digitFromFileName(fileNameStr)
             vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
             classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
             # Single-argument print(...) behaves the same on Python 2 and 3.
             print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
             if classifierResult != classNumStr:
                 errorCount += 1.0

         print("\nthe total number of errors is: %d" % errorCount)
         # Avoid dividing by zero when the test directory is empty.
         errorRate = errorCount / float(mTest) if mTest else 0.0
         print("\nthe total error rate is: %f" % errorRate)


     def _digitFromFileName(fileNameStr):
         """Return the integer before the first '_' of a name like '3_12.txt'."""
         fileStr = fileNameStr.split('.')[0]     # take off .txt
         return int(fileStr.split('_')[0])

    2.数据集(测试集和训练集)

    3.KNN测试结果

  • 相关阅读:
    test
    TCP/IP状态转换图
    用Python操作Excel,实现班级成绩的统计
    树莓派介绍和安装树莓派系统遇到的坑,好痛苦啊
    Eclipse-jee-oxygen-3A版运行时出现Could not create the Java virtual machine?
    eclipse搭建简单的web服务,使用tomcat服务
    嵌入式【杂记--手机芯片与pc】
    tomcat启动不了?
    Selenium的使用
    使用PhantomJS
  • 原文地址:https://www.cnblogs.com/jackchen-Net/p/6800275.html
Copyright © 2011-2022 走看看