  • [Machine Learning in Action] k-Nearest Neighbors, Section 2.2: Dating Site Prediction Function

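    Study notes on Machine Learning in Action

    The book demonstrates its code in Python 2; here I have converted it to Python 3 and added some comments. Learning to debug with breakpoints makes things much easier.

    The code below is the complete test code for Section 2.2 of the book, which uses the k-nearest neighbors algorithm to improve matches on a dating site: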

        from numpy import array, tile, zeros
        import operator
        import matplotlib.pyplot as plt


        def createDataSet():
            group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
            labels = ['A', 'A', 'B', 'B']
            return group, labels


        def classify0(inX, dataSet, labels, k):
            '''
            k-nearest neighbors algorithm
            :param inX: input vector to classify
            :param dataSet: training sample set
            :param labels: label vector
            :param k: number of nearest neighbors to use
            :return: the most common class among the k nearest neighbors, used as the prediction
            '''
            dataSetSize = dataSet.shape[0]
            diffMat = tile(inX, (dataSetSize, 1)) - dataSet
            sqDiffMat = diffMat ** 2
            sqDistances = sqDiffMat.sum(axis=1)
            distances = sqDistances ** 0.5
            # The lines above compute the Euclidean distance between the input vector and each labeled sample
            sortedDistIndicies = distances.argsort()  # argsort returns the indices that sort the array in ascending order
            classCount = {}
            for i in range(k):
                voteIlabel = labels[sortedDistIndicies[i]]
                classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
                # dict.get(key, default) returns the value stored for key, or default if the key is absent.
                # get(voteIlabel, 0) therefore yields the current count for the label, or 0 the first time it is seen.
            sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
            # Sort by the 2nd element of each tuple (the vote count), in descending order
            # Python 2's iteritems() must be changed to items()
            return sortedClassCount[0][0]


        def file2matrix(filename):
            # Parse the text records into a NumPy matrix and a list of labels
            with open(filename) as fr:  # the with block closes the file automatically
                arrayOLines = fr.readlines()
            numberOfLines = len(arrayOLines)
            print(numberOfLines)
            returnMat = zeros((numberOfLines, 3))  # holds the 3 features
            classLabelVector = []  # holds the labels
            index = 0
            for line in arrayOLines:
                line = line.strip()  # strip() removes leading and trailing whitespace (including newlines) by default
                listFromLine = line.split('\t')  # split() cuts the string at each tab and returns a list
                returnMat[index, :] = listFromLine[0:3]  # copy the first 3 values into the current row of returnMat
                classLabelVector.append(int(listFromLine[-1]))  # append the label to classLabelVector
                index += 1
            return returnMat, classLabelVector


        def autoNum(dataSet):
            minVals = dataSet.min(0)  # A.min(0): 1-D array of the minimum of each column of A
            maxVals = dataSet.max(0)  # A.max(0): 1-D array of the maximum of each column of A
            # https://blog.csdn.net/qq_41800366/article/details/86313052
            ranges = maxVals - minVals
            m = dataSet.shape[0]
            normDataSet = dataSet - tile(minVals, (m, 1))
            # tile repeats minVals m times along the rows and once along the columns, so minVals is subtracted from every row
            normDataSet = normDataSet / tile(ranges, (m, 1))
            # Dividing every row by ranges completes the normalization
            return normDataSet, ranges, minVals


        def datingClassTest():
            # Relies on the module-level normMat and datingLabels created in the __main__ block
            hoRatio = 0.10  # fraction of the data held out for testing
            m = normMat.shape[0]
            numTestVecs = int(m * hoRatio)  # number of test samples
            errorCount = 0.0
            for i in range(numTestVecs):
                classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
                print("the classifierResult came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
                if classifierResult != datingLabels[i]:
                    errorCount += 1.0
            # Report the total error rate once, after all test samples have been classified
            print("the total error rate is: %f" % (errorCount / float(numTestVecs)))


        def classifyPerson():
            # Relies on the module-level normMat, ranges, minVals and datingLabels created in the __main__ block
            resultList = ['not at all', 'in small doses', 'in large doses']
            percentTats = float(input("percentage of time spent playing video games?"))
            # Python 3 merged raw_input() and input(): raw_input() was removed and only input() remains.
            # It accepts any input, treats it as a string, and returns a str.
            ffMiles = float(input("frequent flier miles earned per year?"))
            iceCream = float(input("liters of ice cream consumed per year?"))
            inArr = array([ffMiles, percentTats, iceCream])  # input test vector
            classifierResult = classify0((inArr - minVals) / ranges, normMat, datingLabels, 3)  # get the classification
            print("You will probably like this person:", resultList[classifierResult - 1])


        if __name__ == "__main__":
            # group, labels = createDataSet()
            # result = classify0([0, 0], group, labels, 3)
            # print(result)
            datingDataMat, datingLabels = file2matrix("./datingTestSet2.txt")  # convert the data file
            # print(datingDataMat)
            # print(datingLabels)
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1], 15.0 * array(datingLabels), 15.0 * array(datingLabels))
            plt.show()
            normMat, ranges, minVals = autoNum(datingDataMat)  # normalize the input data
            # print(normMat)
            # print(ranges)
            # print(minVals)
            # datingClassTest()
            classifyPerson()  # classify
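
To make the distance-and-vote logic of classify0 easier to follow, here is a minimal sketch on the four-point toy set from createDataSet. It is not the book's code: the helper name knn_classify is hypothetical, and it swaps tile() for NumPy broadcasting and the dictionary count for collections.Counter.

```python
import numpy as np
from collections import Counter


def knn_classify(query, samples, labels, k):
    """Hypothetical helper: majority vote among the k nearest samples."""
    # Broadcasting subtracts the query from every row, so tile() is not needed.
    diffs = samples - np.asarray(query, dtype=float)
    distances = np.sqrt((diffs ** 2).sum(axis=1))  # Euclidean distance to each row
    nearest = distances.argsort()[:k]              # indices of the k smallest distances
    votes = Counter(labels[i] for i in nearest)    # count the labels of those neighbors
    return votes.most_common(1)[0][0]


group = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']
result = knn_classify([0, 0], group, labels, 3)
print(result)  # the three nearest points are labeled B, B, A, so this prints B
```

With k=3 the query [0, 0] sits next to both B points, so the single A neighbor is outvoted.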
    Output:
        percentage of time spent playing video games?10
        frequent flier miles earned per year?4000
        liters of ice cream consumed per year?1
        You will probably like this person: in small doses
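
The answers typed at the prompts are rescaled with the training set's minVals and ranges before they reach classify0; otherwise the flier-miles feature, which is in the thousands, would dominate the distance. The sketch below shows that min-max scaling on a few made-up rows (the numbers are illustrative and are not taken from datingTestSet2.txt), using broadcasting rather than tile().

```python
import numpy as np

# Made-up sample rows: [flier miles, % time gaming, liters of ice cream]
data = np.array([
    [40000.0, 8.0, 1.0],
    [14000.0, 7.0, 1.5],
    [26000.0, 1.0, 0.5],
])

min_vals = data.min(axis=0)              # per-column minimum
ranges = data.max(axis=0) - min_vals     # per-column range (max - min)
norm_data = (data - min_vals) / ranges   # broadcasting applies this to every row

# A new query must be rescaled with the training set's min_vals and ranges
query = np.array([20000.0, 4.5, 1.0])
norm_query = (query - min_vals) / ranges

print(norm_data)
print(norm_query)
```

After scaling, every column of norm_data lies between 0 and 1, so each feature contributes on a comparable scale to the Euclidean distance.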
    Note: Python 2's iteritems() method must be changed to items() in Python 3.
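
As a small illustration of the vote counting this note refers to, the sketch below tallies labels with dict.get() and sorts the (label, count) pairs returned by items(). The label list is made up for the example.

```python
import operator

class_count = {}
for label in ['B', 'B', 'A']:  # made-up labels of the k nearest neighbors
    class_count[label] = class_count.get(label, 0) + 1  # 0 is the default for unseen labels

# items() yields (label, count) pairs; sort by count, largest first
sorted_counts = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
print(sorted_counts[0][0])  # the label with the most votes
```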
  • Original article: https://www.cnblogs.com/DJames23/p/13053974.html