zoukankan      html  css  js  c++  java
  • kNN

    import  numpy as np
    import operator
    import matplotlib
    import matplotlib.pyplot as plt
    import os
    
    
    def createDataSet():
        """
        Build the tiny hand-written sample set used to try out the classifier.

        :return: a tuple ``(group, labels)`` where ``group`` is a 4x2 numpy array
                 of points and ``labels`` is the list of their class names
        """
        # Each entry pairs one 2-D point with the class it belongs to.
        samples = (
            ([1.0, 1.1], 'A'),
            ([1.0, 1.0], 'A'),
            ([0, 0], 'B'),
            ([0, 0.1], 'B'),
        )
        group = np.array([point for point, _ in samples])
        labels = [tag for _, tag in samples]
        return group, labels
    
    def classify0(inX, dataSet, labels, k):
        """
        Classify ``inX`` by majority vote among its k nearest training samples.

        Distances are Euclidean. If several labels collect the same number of
        votes, the one that was seen first while walking the neighbours from
        nearest to farthest wins.

        :param inX: feature vector to classify (anything that broadcasts against
                    one row of ``dataSet``)
        :param dataSet: 2-D numpy array with one training sample per row
        :param labels: sequence where ``labels[i]`` is the class of ``dataSet[i]``
        :param k: number of neighbours that vote; values larger than the number
                  of training samples are clamped to that number
        :return: the winning label
        :raises ValueError: if ``dataSet`` has no rows or ``k`` is smaller than 1
        """
        dataSetSize = dataSet.shape[0]
        if dataSetSize == 0:
            raise ValueError("dataSet must contain at least one sample")
        if k < 1:
            raise ValueError("k must be at least 1")
        # Asking for more neighbours than there are samples just means "use all".
        k = min(k, dataSetSize)

        # Broadcasting subtracts inX from every row without building a tiled copy.
        diffMat = inX - dataSet
        distances = ((diffMat ** 2).sum(axis=1)) ** 0.5
        sortedDistIndicies = distances.argsort()

        classCount = {}
        for neighbourIndex in sortedDistIndicies[:k]:
            voteIlabel = labels[neighbourIndex]
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
        # max() returns the first label reaching the top count, i.e. insertion
        # order breaks ties.
        return max(classCount, key=classCount.get)
    
    
    def file2matrix(filename):
        """
        Parse a tab-separated text file into a feature matrix and a label list.

        Every non-blank line is split on tab characters; its first three fields
        become one row of the matrix and its last field is converted to ``int``
        and used as the class label. Blank lines are skipped.

        :param filename: path of the text file to read
        :return: a tuple ``(returnMat, classLabelVectors)`` where ``returnMat`` is
                 an ``(n, 3)`` float array and ``classLabelVectors`` is a list of
                 ``n`` ints
        """
        # The context manager closes the file even if parsing fails later on.
        with open(filename) as fr:
            rows = [line.strip().split('\t') for line in fr if line.strip()]

        returnMat = np.zeros((len(rows), 3))
        classLabelVectors = []
        for index, listFromLine in enumerate(rows):
            returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
            classLabelVectors.append(int(listFromLine[-1]))
        return returnMat, classLabelVectors
    
    
    """
    data, labels = file2matrix("datingTestSet.txt")
    # print(data)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.scatter(data[:, 0], data[:, 1], 15*np.array(labels), 15*np.array(labels))
    plt.show()
    """
    
    
    def autoNorm(dataSet):
        """
        Rescale every feature column with min-max normalization.

        Each value becomes ``(value - column_min) / (column_max - column_min)``.
        A column whose values are all identical has a range of zero; such a
        column is mapped to 0 instead of producing NaN from a 0/0 division.

        :param dataSet: 2-D numpy array with one sample per row
        :return: a tuple ``(normDataSet, ranges, minVals)``: the normalized data,
                 the per-column ``max - min`` (zero for constant columns) and
                 the per-column minimum
        """
        minVals = dataSet.min(0)
        maxVals = dataSet.max(0)
        ranges = maxVals - minVals
        # Divide by 1 where the range is 0; the numerator is 0 there anyway.
        # The returned ``ranges`` keeps the true values for callers.
        safeRanges = np.where(ranges == 0, 1, ranges)
        # Broadcasting applies the per-column vectors to every row.
        normDataSet = (dataSet - minVals) / safeRanges
        return normDataSet, ranges, minVals
    
    
    def datingClassTest():
        """
        Measure the classifier's error rate on the dating data.

        Loads 'datingTestSet2.txt', normalizes it, holds out the first 10% of
        the rows as test samples and classifies each of them (k = 3) against
        the remaining rows. Prints every prediction next to the true label and
        finally the overall error rate.
        """
        hoRatio = 0.1
        features, trueLabels = file2matrix('datingTestSet2.txt')
        scaled, _ranges, _minVals = autoNorm(features)
        total = scaled.shape[0]
        holdout = int(total*hoRatio)

        # Rows from `holdout` onward form the training set for every test row.
        trainRows = scaled[holdout:total, :]
        trainLabels = trueLabels[holdout:total]

        mistakes = 0.0
        for row in range(holdout):
            predicted = classify0(scaled[row, :], trainRows, trainLabels, 3)
            expected = trueLabels[row]
            print("the classifier came back with: %d, the real answer is: %d" % (predicted,
                                                                                 expected))
            if predicted != expected:
                mistakes += 1.0
        print("the total error rate is: %f" % (mistakes/float(holdout)))
    
    
    def classifyPerson():
        """
        Interactive predictor for the dating data set.

        Asks the user for three numbers, normalizes them with the minimum and
        range computed from 'datingTestSet2.txt', classifies the result with
        k = 3 and prints the matching entry of the result list.
        """
        resultList = ['not at all', 'in small doses', 'in large doses']
        prompts = (
            "percentage of time spent playing video games?  ",
            "frequent flier miles earned per year?  ",
            "liters of ice cream consumed per year?  ",
        )
        percentTats, ffMiles, iceCream = [float(input(question)) for question in prompts]

        trainingFeatures, trainingLabels = file2matrix('datingTestSet2.txt')
        scaled, spans, lows = autoNorm(trainingFeatures)

        # The query vector is ordered miles, game-time percentage, ice cream.
        query = (np.array([ffMiles, percentTats, iceCream]) - lows) / spans
        verdict = classify0(query, scaled, trainingLabels, 3)
        # Labels are used as 1-based positions into resultList.
        print("You will probably like this person:", resultList[verdict - 1])
    
    
    def img2vector(filename):
        """
        Flatten a 32x32 text image into a single row vector.

        Reads the first 32 lines of the file and converts the first 32
        characters of each line to ``int``, storing them row after row.

        :param filename: path of the text file holding the image
        :return: numpy array of shape ``(1, 1024)``
        """
        returnVect = np.zeros((1, 1024))
        # The context manager closes the file even if a line fails to parse.
        with open(filename) as fr:
            for i in range(32):
                lineStr = fr.readline()
                for j in range(32):
                    # Row-major flattening: pixel (i, j) lands at 32*i + j.
                    returnVect[0, 32*i + j] = int(lineStr[j])
        return returnVect
    
    # Runs at module load: converts one sample digit file into a vector.
    # NOTE(review): testVector is not referenced anywhere else in the visible
    # source, and this line fails if the file is missing - confirm it is needed.
    testVector = img2vector('digits/testDigits/0_13.txt')
    
    
    def handwritingClassTest():
        """
        Evaluate the kNN classifier on the handwritten-digit files.

        Every file in 'digits/trainingDigits' is converted to a 1024-element
        vector and labelled with the integer found before the first '_' in its
        name (e.g. '0_13.txt' -> 0). Every file in 'digits/testDigits' is then
        classified with k = 3 and compared with the label taken from its own
        name. Prints each prediction, the number of errors and the error rate.
        """
        hwLabels = []
        trainingFileList = os.listdir('digits/trainingDigits')
        m = len(trainingFileList)
        trainingMat = np.zeros((m, 1024))
        for i, fileNameStr in enumerate(trainingFileList):
            fileStr = fileNameStr.split('.')[0]
            hwLabels.append(int(fileStr.split('_')[0]))
            trainingMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr)

        testFileList = os.listdir('digits/testDigits')
        errorCount = 0.0
        mTest = len(testFileList)
        for fileNameStr in testFileList:
            classNumStr = int(fileNameStr.split('_')[0])
            vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
            classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
            print("the classifier came back with: %d, the real answer is: %d" % (classifierResult,
                                                                                 classNumStr))
            if classifierResult != classNumStr:
                errorCount += 1.0
        # The leading newline is written as an escape; a raw line break inside
        # the quotes would leave the string literal unterminated.
        print("\nthe total number of errors is: %d" % errorCount)
        print("\nthe total error rate is: %f" % (errorCount/float(mTest)))
    
    
    # Script entry point: runs the handwritten-digit evaluation.
    # There is no __main__ guard, so this also runs when the module is imported.
    handwritingClassTest()
  • 相关阅读:
    【Linux】项目部署
    【架构师之路】【MQ】消息队列
    【数据库】【Python】mysql
    【算法】【Python】找出字符串中重复出现的字符 并求出重复次数 且根据重复次数从大到小排列
    【Python】排序 按照list中的字典的某key排序
    Kettle Post请求webservice
    python+pytest+allure接口自动化测试框架
    Python+unittest+requests+htmlTestRunner+excel完整的接口自动化框架
    python实现栈的基本操作
    展示博客园顶部的随笔、文章、评论、阅读量统计数据
  • 原文地址:https://www.cnblogs.com/key221/p/9741236.html
Copyright © 2011-2022 走看看