zoukankan      html  css  js  c++  java
  • 机器学习 MLIA学习笔记(三)之 KNN(二) Dating可能性实例

    这是 KNN 算法的另一个实例:根据约会对象的三项特征(每年获得的飞行常客里程数、玩视频游戏所占时间百分比、每年消耗的冰淇淋升数),预测你对此人的喜欢程度。

    import numpy as np
    import os
    import operator
    import matplotlib
    import matplotlib.pyplot as plt
    
    def classify(inX, dataSet, labels, k):
        """Classify one sample with the k-nearest-neighbours majority vote.

        Parameters
        ----------
        inX : array-like, shape (n_features,)
            The sample to classify.
        dataSet : numpy.ndarray, shape (n_samples, n_features)
            Training samples, one per row.
        labels : sequence
            Training labels; ``labels[i]`` is the label of ``dataSet[i]``.
        k : int
            Number of nearest neighbours that vote. Values larger than the
            number of training samples are clamped to that number.

        Returns
        -------
        The label with the most votes among the ``k`` training samples
        closest to ``inX`` (Euclidean distance). Ties keep the label of the
        nearer neighbour, because votes are counted nearest-first and the
        sort is stable.

        Raises
        ------
        ValueError
            If ``dataSet`` is empty or ``k < 1``.
        """
        dataSetSize = dataSet.shape[0]  # number of training samples (rows)
        if dataSetSize == 0:
            raise ValueError("dataSet must contain at least one sample")
        if k < 1:
            raise ValueError("k must be >= 1, got %r" % (k,))
        # Asking for more neighbours than exist would index past the end.
        k = min(k, dataSetSize)
        # Broadcasting subtracts inX from every row; no need to np.tile it.
        diffMat = np.asarray(inX) - dataSet
        distances = np.sqrt((diffMat ** 2).sum(axis=1))
        # Indices of the training samples ordered from nearest to farthest.
        sortedDistanceIndices = distances.argsort()
        classCount = {}
        for i in sortedDistanceIndices[:k]:
            voteIlabel = labels[i]
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
        # .items() works on Python 2 and 3 (iteritems is Python-2 only).
        # Sort (label, votes) pairs by vote count, largest first.
        sortedClassCount = sorted(classCount.items(),
                                  key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]
    
    def file2matrix(fileName):
        """Load a tab-separated data file into a feature matrix and label list.

        Each non-blank line must hold at least three tab-separated fields.
        The first three are parsed as float features and the last field is
        kept as the label string.

        Parameters
        ----------
        fileName : str
            Path of the text file to read.

        Returns
        -------
        returnMat : numpy.ndarray, shape (n_lines, 3)
            Feature values, one row per non-blank line.
        classLabelVector : list of str
            The last field of each non-blank line, in file order.
        """
        # Read the file once and close it deterministically.
        with open(fileName) as fileHandler:
            lines = [line.strip() for line in fileHandler]
        # Blank lines (e.g. an extra trailing newline) carry no sample and
        # would otherwise make the float conversion below fail.
        lines = [line for line in lines if line]
        returnMat = np.zeros((len(lines), 3))
        classLabelVector = []
        for index, line in enumerate(lines):
            listFromLine = line.split('\t')
            returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
            classLabelVector.append(listFromLine[-1])
        return returnMat, classLabelVector
    
    def autoNorm(dataSet):
        """Min-max scale every column of ``dataSet`` into the range [0, 1].

        Each value becomes ``(value - columnMin) / (columnMax - columnMin)``.

        Parameters
        ----------
        dataSet : numpy.ndarray, shape (n_samples, n_features)

        Returns
        -------
        normDataSet : numpy.ndarray
            The scaled data, same shape as ``dataSet``.
        ranges : numpy.ndarray, shape (n_features,)
            Per-column ``max - min``. A column whose values are all equal
            gets a range of 1, so it scales to 0 instead of NaN.
        minVal : numpy.ndarray, shape (n_features,)
            Per-column minimum.
        """
        minVal = dataSet.min(0)
        maxVal = dataSet.max(0)
        ranges = maxVal - minVal
        # A constant column has range 0; 0/0 would produce NaN and poison
        # every distance computed from the normalised data. Callers that
        # rescale new samples with the returned ranges stay consistent.
        ranges = np.where(ranges == 0, 1, ranges)
        # Broadcasting applies the per-column min and range to every row.
        normDataSet = (dataSet - minVal) / ranges
        return normDataSet, ranges, minVal
    
    def showMatrix():
        """Scatter-plot normalised feature columns 1 and 2 of datingTestSet.txt.

        Loads the data, min-max normalises it and shows column 1 (x axis)
        against column 2 (y axis) in a blocking matplotlib window.
        """
        rawData = file2matrix("datingTestSet.txt")[0]
        normData = autoNorm(rawData)[0]
        figure = plt.figure()
        axes = figure.add_subplot(111)
        xs = normData[:, 1]
        ys = normData[:, 2]
        axes.scatter(xs, ys)
        plt.show()
    
    def calcErrorRate(ratio=0.1, k=3, fileName="datingTestSet.txt"):
        """Estimate the kNN classifier's error rate with a hold-out split.

        The first ``int(n_samples * ratio)`` samples of ``fileName`` form the
        test set. Each is classified with ``classify(..., k)`` against the
        remaining samples, after min-max normalisation of the whole data set.
        Every prediction, the error rate and the raw error count are printed.

        Parameters
        ----------
        ratio : float
            Fraction of samples held out for testing (default 0.1, i.e. 10%).
        k : int
            Number of neighbours passed to ``classify`` (default 3).
        fileName : str
            Data file read with ``file2matrix``.

        Returns
        -------
        float
            The fraction of misclassified test samples.

        Raises
        ------
        ValueError
            If the split leaves no test samples or no training samples.
        """
        matrix, labels = file2matrix(fileName)
        matrix = autoNorm(matrix)[0]
        m = matrix.shape[0]
        numTestSample = int(m * ratio)
        # Without test samples the rate is 0/0; without training samples
        # there is nothing to classify against.
        if numTestSample <= 0 or numTestSample >= m:
            raise ValueError(
                "ratio=%r gives %d test and %d training samples; both must be > 0"
                % (ratio, numTestSample, m - numTestSample))
        # The training split is the same for every test sample: slice it once.
        trainMat = matrix[numTestSample:m, :]
        trainLabels = labels[numTestSample:m]
        errorCount = 0
        for i in range(numTestSample):
            classifyResult = classify(matrix[i, :], trainMat, trainLabels, k)
            # Single-argument print(...) behaves the same on Python 2 and 3.
            print("the classifier came back with: %s, the real answer is: %s"
                  % (classifyResult, labels[i]))
            if classifyResult != labels[i]:
                errorCount += 1
        # float() keeps true division on Python 2 as well.
        errorRate = errorCount / float(numTestSample)
        print("the total error rate is: %f" % errorRate)
        print(errorCount)
        return errorRate
    
    def classifyPerson():
        """Ask for one person's three features interactively and print the kNN verdict.

        The answers are normalised with the min/range of datingTestSet.txt
        and classified against that data set with k=3.

        Returns
        -------
        The predicted label (also printed).
        """
        # raw_input exists only on Python 2; Python 3 renamed it to input().
        try:
            readLine = raw_input
        except NameError:
            readLine = input
        percentTats = float(readLine(
                    "percentage of time spent playing video games?"))
        ffMiles = float(readLine("frequent flier miles earned per year?"))
        iceCream = float(readLine("liters of ice cream consumed per year?"))
        datingDataMat, datingLabels = file2matrix("datingTestSet.txt")
        normMat, ranges, minVal = autoNorm(datingDataMat)
        # NOTE(review): this order is assumed to match the column order of
        # datingTestSet.txt (miles, game-time %, ice cream) - confirm.
        inArr = np.array([ffMiles, percentTats, iceCream])
        # Scale the new sample exactly like the training data.
        classifyResult = classify((inArr - minVal) / ranges, normMat,
                                  datingLabels, 3)
        print("You will probably like this person: %s" % (classifyResult,))
        return classifyResult
  • 相关阅读:
    枚举工具类:封装判断是否存在这个枚举
    MYSQL插入emoji报错解决方法Incorrect string value
    文件大小转换带上单位工具类(文件byte自动转KBMBGB)
    mysql 统计七天数据并分组
    mybatis plus 和 druid 版本导致LocalDateTime 不兼容问题
    Layui弹框中select下拉列表赋值回显
    查看环境版本
    Linux 常用命令
    安装jdk14的坑
    modbus_tk解析
  • 原文地址:https://www.cnblogs.com/AmitX-moten/p/4176728.html
Copyright © 2011-2022 走看看