zoukankan      html  css  js  c++  java
  • KNN算法实现手写数字识别

    from numpy import *
    import operator
    from os import listdir
    
    
    def classify0(inX, dataSet, labels, k):
        """Classify ``inX`` by majority vote among its ``k`` nearest rows of ``dataSet``.

        Distance is Euclidean. ``labels[i]`` is the label belonging to row ``i``
        of ``dataSet``. When several labels share the highest vote count, the
        one that entered the tally first (i.e. was met first while walking the
        neighbours from nearest to farthest) is returned.

        Raises IndexError when ``k`` exceeds the number of rows, or when no
        vote was cast (``k`` <= 0).
        """
        rowCount = dataSet.shape[0]
        # Repeat the query once per training row so the subtraction is row-wise.
        offsets = tile(inX, (rowCount, 1)) - dataSet
        distances = ((offsets ** 2).sum(axis=1)) ** 0.5
        nearestFirst = distances.argsort()

        votes = {}
        for rank in range(k):
            label = labels[nearestFirst[rank]]
            if label in votes:
                votes[label] += 1
            else:
                votes[label] = 1

        # The sort is stable, so labels with equal counts keep insertion order.
        ranking = list(votes.items())
        ranking.sort(key=lambda pair: pair[1], reverse=True)
        return ranking[0][0]
    
    
    def img2Vector(filename):
        """Read a 32x32 grid of digit characters from a text file into a 1x1024 row.

        The first 32 lines of ``filename`` are read, and the first 32
        characters of each line are converted with ``int()``. The value at
        line ``i``, column ``j`` is stored at index ``32*i + j`` of the result.

        Args:
            filename: Path of the text file to read.

        Returns:
            A NumPy array of shape (1, 1024) holding the converted values.

        Raises:
            IndexError: If one of the first 32 lines has fewer than 32 characters
                (including when the file has fewer than 32 lines).
            ValueError: If a character cannot be converted by ``int()``.
        """
        returnVect = zeros((1, 1024))
        # The context manager closes the file on every exit path, including
        # when a malformed line raises part-way through parsing.
        with open(filename) as fr:
            for i in range(32):
                lineStr = fr.readline()
                for j in range(32):
                    returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect
    
    
    def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits', k=3):
        """Evaluate the kNN classifier on a directory of test images and print results.

        Every file in ``trainingDir`` is loaded with ``img2Vector`` into one row
        of the training matrix. The true label of a file is the integer that
        precedes the first ``_`` in its name (after dropping everything from
        the first ``.`` onward). Each file in ``testDir`` is then classified
        with ``classify0`` and compared against the label from its own name.

        Args:
            trainingDir: Directory holding the training files.
            testDir: Directory holding the test files.
            k: Number of neighbours passed to ``classify0``.

        Prints one line per test file, then the error count and, when at least
        one test file exists, the error rate. Returns nothing.
        """
        trainingFileList = listdir(trainingDir)
        m = len(trainingFileList)
        hwLabels = []
        trainingMat = zeros((m, 1024))
        for i, fileNameStr in enumerate(trainingFileList):
            fileStr = fileNameStr.split('.')[0]
            hwLabels.append(int(fileStr.split('_')[0]))
            trainingMat[i, :] = img2Vector('%s/%s' % (trainingDir, fileNameStr))

        testFileList = listdir(testDir)
        mTest = len(testFileList)
        errorCount = 0.0
        for fileNameStr in testFileList:
            fileStr = fileNameStr.split('.')[0]
            classNumStr = int(fileStr.split('_')[0])
            vectorUnderTest = img2Vector('%s/%s' % (testDir, fileNameStr))
            classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
            print("the classifier came back with:%d,the real answer is :%d"%(classifierResult,classNumStr))
            if classifierResult != classNumStr:
                errorCount += 1
        print("the total number of errors is :%d"%errorCount)
        if mTest == 0:
            # With no test files the rate is undefined; the unguarded division
            # would raise ZeroDivisionError.
            print("no test files found in %s; error rate is undefined" % testDir)
            return
        print("the total error rate is: %f"%(errorCount/float(mTest)))
    
    # Runs the digit-classification evaluation as soon as this module is
    # executed or imported (there is no __main__ guard).
    handwritingClassTest()

    测试集+训练集数据地址:https://i.cnblogs.com/Files.aspx

    knn.rar

  • 相关阅读:
    CSV文件读取类
    一个参数处理类
    记一个mysql的问题
    php问题小记
    wsl开nginx和php-fpm遇到的几个小问题
    debian apache2.4 virtual host 使用
    debian 安装 apache2和php7
    杂记整理三:php、thinkphp和sql
    杂记整理二:linux与程序安装
    杂记整理一:javascript, jQuery 以及 ECMAscript
  • 原文地址:https://www.cnblogs.com/ncuhwxiong/p/9460380.html
Copyright © 2011-2022 走看看