zoukankan      html  css  js  c++  java
  • 数据挖掘实践(13):基础理论(十三)KNN算法(二)KNN算法的实现

    算法实现

    def classify0(inX, dataSet, labels, k):
        """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

        Distance is the Euclidean distance between ``inX`` and every row of
        ``dataSet``.

        Args:
            inX: Query sample, either a 1-D sequence of features or an array
                of shape (1, n_features).
            dataSet: Training samples, one sample per row.
            labels: Class labels; ``labels[i]`` belongs to ``dataSet[i]``.
            k: Number of nearest neighbours that vote. Values larger than the
                number of training rows are clamped to that number.

        Returns:
            The label with the most votes. On a tie, the label that was seen
            first while walking from the nearest neighbour outwards wins.

        Raises:
            ValueError: If ``dataSet`` has no rows or ``k`` is smaller than 1.
        """
        dataSet = np.asarray(dataSet)
        dataSetSize = dataSet.shape[0]
        if dataSetSize == 0:
            raise ValueError("dataSet must contain at least one sample")
        if k < 1:
            raise ValueError("k must be at least 1")

        # Distance metric: Euclidean distance. Broadcasting subtracts the
        # query from every training row without building a tiled copy.
        diffMat = np.asarray(inX) - dataSet
        sqDistances = (diffMat ** 2).sum(axis=1)
        distances = sqDistances ** 0.5

        # Sort the distances in ascending order.
        sortedDistIndicies = distances.argsort()

        # Take the k shortest distances and count the votes per class.
        classCount = {}
        for neighbourIndex in sortedDistIndicies[:min(k, dataSetSize)]:
            voteIlabel = labels[neighbourIndex]
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

        # max() returns the first entry with the highest count, which matches
        # what a stable descending sort followed by [0] would select.
        return max(classCount.items(), key=operator.itemgetter(1))[0]
    def img2vector(filename):
        """Read a 32x32 grid of digit characters into a 1x1024 row vector.

        The first 32 characters of each of the first 32 lines are converted
        with ``int`` and stored row by row.

        Args:
            filename: Path of the text file to read.

        Returns:
            A NumPy array of shape (1, 1024); element ``32 * i + j`` holds the
            value of character ``j`` on line ``i``.
        """
        returnVect = np.zeros((1, 1024))
        # The context manager closes the file even if a line is malformed.
        with open(filename) as fr:
            for i in range(32):
                lineStr = fr.readline()
                for j in range(32):
                    returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect
    def _labelFromFileName(fileNameStr):
        """Return the integer that precedes the first '_' in the file name's stem."""
        fileStr = fileNameStr.split('.')[0]  # take off .txt
        return int(fileStr.split('_')[0])


    def handwritingClassTest(trainingDir='2.KNN/trainingDigits',
                             testDir='2.KNN/testDigits', k=3):
        """Evaluate the KNN classifier on the digit files and print the results.

        Every file in ``trainingDir`` becomes one row of the training matrix.
        Every file in ``testDir`` is classified with ``classify0`` and compared
        with the label parsed from its file name. One line is printed per test
        file, followed by the total error count and the error rate.

        Args:
            trainingDir: Directory holding the training files.
            testDir: Directory holding the test files.
            k: Number of neighbours passed to ``classify0``.
        """
        # 1. Load the training data.
        hwLabels = []
        # Sorted so the printed output does not depend on the order in which
        # the operating system lists the directory.
        trainingFileList = sorted(os.listdir(trainingDir))  # load the training set
        m = len(trainingFileList)
        trainingMat = np.zeros((m, 1024))
        # hwLabels[i] is the label of the file whose vector is trainingMat[i].
        for i, fileNameStr in enumerate(trainingFileList):
            hwLabels.append(_labelFromFileName(fileNameStr))
            # Flatten the 32*32 matrix into a 1*1024 row.
            trainingMat[i, :] = img2vector('%s/%s' % (trainingDir, fileNameStr))

        # 2. Load the test data and classify each file.
        testFileList = sorted(os.listdir(testDir))  # iterate through the test set
        errorCount = 0.0
        mTest = len(testFileList)
        for fileNameStr in testFileList:
            classNumStr = _labelFromFileName(fileNameStr)
            vectorUnderTest = img2vector('%s/%s' % (testDir, fileNameStr))
            classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
            print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
            if classifierResult != classNumStr:
                errorCount += 1.0

        # An empty test directory would otherwise divide by zero.
        errorRate = errorCount / float(mTest) if mTest else 0.0
        print("\nthe total number of errors is: %d" % errorCount)
        print("\nthe total error rate is: %f" % errorRate)
    import os
    import numpy as np
    import operator
    # Run the evaluation only when this file is executed as a script, so that
    # importing it does not read the digit directories as a side effect.
    if __name__ == "__main__":
        handwritingClassTest()
    优点:准确率高,对异常值不敏感
    缺点:空间复杂度和时间复杂度都很高(需要保存全部训练样本,每次预测都要与所有样本计算距离)
  • 相关阅读:
    变量 常量 Python变量内存管理 赋值方式 注释
    leetcode 两数之和 整数反转 回文数 罗马数字转整数
    计算机基础之编程
    列表,集合,元组,字典
    小练习
    Ansi 与 Unicode 字符串类型的互相转换
    UVALive
    UVA
    UVA 10651 Pebble Solitaire 状态压缩dp
    UVA 825 Walking on the safe side
  • 原文地址:https://www.cnblogs.com/qiu-hua/p/14322144.html
Copyright © 2011-2022 走看看