zoukankan      html  css  js  c++  java
  • 数据挖掘实践(13):基础理论(十三)KNN算法(二)KNN算法的实现

    算法实现

    def classify0(inX, dataSet, labels, k):
        """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

        Distance is the Euclidean distance to every row of ``dataSet``.

        Args:
            inX: query vector, either 1-D with ``dataSet.shape[1]`` elements or a
                single-row 2-D array (the shape ``img2vector`` returns).
            dataSet: 2-D array holding one training sample per row.
            labels: sequence of class labels; ``labels[i]`` belongs to row ``i``.
            k: number of neighbours that vote. A ``k`` larger than the number of
                training samples is clamped to that number.

        Returns:
            The label with the most votes. When several labels tie, the one whose
            first vote came from the nearer neighbour wins.

        Raises:
            ValueError: if ``k < 1`` or ``dataSet`` has no rows.
        """
        from collections import Counter

        dataSetSize = dataSet.shape[0]
        if k < 1:
            raise ValueError("k must be at least 1, got %r" % (k,))
        if dataSetSize == 0:
            raise ValueError("dataSet must contain at least one sample")

        # Euclidean distance. Broadcasting subtracts the query from every row
        # without materializing an n x d tiled copy of it.
        diffMat = np.asarray(inX) - dataSet
        distances = np.sqrt((diffMat ** 2).sum(axis=1))

        # Sort distances ascending. The sort is stable, so equidistant samples
        # are taken in dataSet order and the result is deterministic. Slicing
        # past the end simply yields every sample, which clamps k.
        nearest = distances.argsort(kind="stable")[:k]

        # Majority vote. most_common keeps first-encountered order for equal
        # counts, i.e. the label seen first (the nearer one) wins a tie.
        votes = Counter(labels[i] for i in nearest)
        return votes.most_common(1)[0][0]
    def img2vector(filename):
        """Read a 32x32 text-encoded image into a 1x1024 row vector.

        Each of the first 32 lines of the file must begin with 32 digit
        characters. The digit at row ``i``, column ``j`` is stored at column
        ``32 * i + j`` of the result.

        Args:
            filename: path of the text file to read.

        Returns:
            A float numpy array of shape (1, 1024).
        """
        returnVect = np.zeros((1, 1024))
        # ``with`` closes the file handle even if a line is malformed and int()
        # or the indexing raises.
        with open(filename) as fr:
            for i in range(32):
                lineStr = fr.readline()
                for j in range(32):
                    returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect
    def _digit_label(fileNameStr):
        """Return the class digit encoded as the prefix of a name like '7_12.txt'."""
        fileStr = fileNameStr.split('.')[0]  # take off .txt
        return int(fileStr.split('_')[0])

    def _list_digit_files(directory):
        """Return the sorted names of the ``.txt`` files in ``directory``."""
        # Skip stray entries (e.g. .DS_Store) whose names would break _digit_label.
        # Sorting makes the run reproducible; os.listdir order is arbitrary.
        return sorted(name for name in os.listdir(directory) if name.endswith('.txt'))

    def handwritingClassTest(trainingDir='2.KNN/trainingDigits',
                             testDir='2.KNN/testDigits', k=3):
        """Train a kNN digit classifier on one directory and report its error on another.

        Every ``.txt`` file in each directory is a 32x32 text image read by
        ``img2vector``. Its true digit is the part of the file name before the
        first ``_`` (e.g. ``7_12.txt`` -> 7). Each test image is classified with
        ``classify0`` against all training images. One result line is printed
        per test image, followed by the total error count and error rate.

        Args:
            trainingDir: directory holding the training images.
            testDir: directory holding the test images.
            k: number of neighbours passed to ``classify0``.

        Returns:
            The error rate (errors / number of test images) as a float.

        Raises:
            ValueError: if either directory contains no ``.txt`` files.
        """
        # 1. Load the training data: one 1024-long row per image plus its digit.
        trainingFileList = _list_digit_files(trainingDir)
        if not trainingFileList:
            raise ValueError("no training files found in %r" % (trainingDir,))
        hwLabels = [_digit_label(name) for name in trainingFileList]
        trainingMat = np.zeros((len(trainingFileList), 1024))
        for i, fileNameStr in enumerate(trainingFileList):
            # Flatten the 32x32 image into row i of the training matrix.
            trainingMat[i, :] = img2vector(os.path.join(trainingDir, fileNameStr))

        # 2. Classify every test image and count the mistakes.
        testFileList = _list_digit_files(testDir)
        if not testFileList:
            # Without this the final rate would divide by zero.
            raise ValueError("no test files found in %r" % (testDir,))
        errorCount = 0
        for fileNameStr in testFileList:
            classNumStr = _digit_label(fileNameStr)
            vectorUnderTest = img2vector(os.path.join(testDir, fileNameStr))
            classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
            print("the classifier came back with: %d, the real answer is: %d"
                  % (classifierResult, classNumStr))
            if classifierResult != classNumStr:
                errorCount += 1

        errorRate = errorCount / float(len(testFileList))
        print("\nthe total number of errors is: %d" % errorCount)
        print("\nthe total error rate is: %f" % errorRate)
        return errorRate
    import os
    import numpy as np
    import operator

    # Run the experiment only when executed as a script. Importing this module
    # (e.g. to reuse classify0) should not start reading the digit directories.
    # The imports above sit after the function definitions, which still works
    # because those functions look the names up only when called.
    if __name__ == "__main__":
        handwritingClassTest()
    优点：准确率高，对异常值不敏感。
    缺点：时间复杂度和空间复杂度都很高——需要保存全部训练样本，并且每次预测都要计算待分类样本与所有训练样本之间的距离。
  • 相关阅读:
    iOS 与 惯性滚动
    前端性能优化--为什么DOM操作慢?
    React虚拟DOM浅析
    DOM性能瓶颈与Javascript性能优化
    React 组件性能优化
    重绘与回流——影响浏览器加载速度
    移动前端开发之viewport的深入理解
    [转] 前后端分离开发模式的 mock 平台预研
    [Unity3D]Unity3D游戏开发之Lua与游戏的不解之缘终结篇:UniLua热更新全然解读
    关联规则( Association Rules)之频繁模式树(FP-Tree)
  • 原文地址:https://www.cnblogs.com/qiu-hua/p/14322144.html
Copyright © 2011-2022 走看看