zoukankan      html  css  js  c++  java
  • 机器学习实战-KNN(K-近邻算法)详解

    KNN(K-近邻算法):算法本身是一个有监督学习的算法,故训练数据是有标签的,算法的原理是计算测试数据距离训练数据的距离(一般是欧式距离),将计算出的距离进行从小到大的排序,取前K个距离对应的训练数据,计算这K个数据中不同标签所占比例,比例最高的标签即为测试数据所属于的类

    以下为python实现K-近邻算法详解:

      1 #coding:UTF-8
      2 
      3 from numpy import *
      4 import operator
      5 
      6 def createDataSet():
      7     group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
      8     labels = ['A','A','B','B']
      9     return group,labels
     10 
     11 
     12 #k-近邻算法
     13 #inX 用于分类的输入向量
     14 #dataSet 训练样本集
     15 #labels  标签向量
     16 # 选择最近邻居的数目
     17 
     18 def classify0(inX,dataSet,labels,k):
     19     #shape 返回数组的行列数  shape[0]为数组的行数
     20     dataSetSize = dataSet.shape[0]
     21 
     22     #tile(x,(dataSetSize,1)) tile方法是将数组x 在列方向上复制1行,行方向上复制dateSetSize次
     23     # >>> import numpy
     24     # >>> numpy.tile([0,0],5) 在列方向上重复[0,0] 5次,行默认1次
     25     # array([0,0,0,0,0,0,0,0,0,0])
     26     # >>> numpy.tile([0,0],(1,1)) 复制[0,0] 在列方向上1次  在行方向上1次
     27     # array([[0,0]])
     28     # >>> numpy.tile([0,0],(2,1)) 在列方向上复制1次 行方向上复制2次
     29     # array([[0,0],
     30     #     [0,0])
     31     # >>> numpy.tile([0,0],(3,1)) 列1次 行3次
     32     # array([[0, 0],
     33     #        [0, 0],
     34     #        [0, 0]])
     35     # >>> numpy.tile([0,0],(1,3)) 列3次 行1次
     36     # array([[0,0,0,0,0,0]])
     37     # >>> numpy.tile([0,0],(2,3)) 列3次 行2次
     38     # array([[0,0,0,0,0,0]
     39     #       [0,0,0,0,0,0]])
     40 
     41     
     42     #复制后计算差值
     43     diffMat = tile(inX, (dataSetSize,1)) - dataSet
     44 
     45     #将diffMat数组中的每个元素进行平方
     46     sqDiffMat = diffMat ** 2
     47 
     48     #将数组sqDiffMat按行相加 axis=1 表示按照横轴 sum表示累加
     49     sqDistances = sqDiffMat.sum(axis = 1)
     50     
     51     #将sqDistances开根号
     52     distances = sqDistances ** 0.5
     53 
     54 
     55     #以上部分即为欧式距离公式 计算两个向量点之间的距离
     56     # (1)二维平面上两点a(x1,y1)与b(x2,y2)间的欧氏距离:x1 - x2的平方加上 y1-y2的平方 然后开根号 
     57 
     58 
     59 
     60     # 按照升序进行快速排序,返回的是原数组的下标。  
     61     # 比如,x = [30, 10, 20, 40]  
     62     # 升序排序后应该是[10,20,30,40],他们的原下标是[1,2,0,3]  
     63     # 那么,numpy.argsort(x) = [1, 2, 0, 3] 
     64     sortedDistIndicies = distances.argsort()
     65     
     66     
     67     #存放最终的分类结果及相应的结果投票数的字典
     68     classCount = {}
     69 
     70     #统计前k个最近的样本所属类别包含的样本个数
     71     for i in range(k):
     72         #sortedDistIndicies[i] 第i个样本的下标
     73         #voteIlabel = labels[sortedDistIndicies[i]] 是对应labels的结果(“A” or “B”) 
     74         voteIlabel = labels[sortedDistIndicies[i]]
     75         
     76         #classCount.get(voteIlabel,0) 返回voteIlabel的值 不存在则返回0
     77         #然后将对应结果加1
     78         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
     79     #sorted 排序函数
     80     #(1)第一个参数是要排序的list或者lables   classCount.iteritems()表示迭代输出字典的键值对这里应该是类似{"A":3,"B:2"} 这样的字典
     81     #(2)key为函数,指定取待排序元素的哪一项进行排序
     82     #operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为一些序号(即需要获取的数据在对象中的序号)
     83     #>>> import operator
     85     #>>> a = [1,2,3]
     86     #>>> b = operator.itemgetter(1) 定义函数b,获取对象的第1个域的值
     87     #>>> b(a)
     88     #2
     89     #>>> b = operator.itemgetter(1,0) 定义函数b,获取对象的第一个域和第0个值
     90     #>>> b(a)
     91     #(2, 1)
     92     #>>> b = operator.itemgetter(1,1)
     93     #>>> b(a)
     94     #(2, 2)
     95     #这里是{"A":3,"B:2"}的对象 所以应该按照values排序即随想的第一个域即operator.itemgetter(1)
     96     #要注意,operator.itemgetter函数获取的不是值,而是定义了一个函数,通过该函数作用到对象上才能获取值
     97     #(3)reverse 排序规则 默认False 升序排列  True 降序排列
     98     #(4)返回值:是一个经过排序的可迭代类型,与iterable一样
     99     sortedClassCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True)
    100     #返回数量最多的类
    101     return sortedClassCount[0][0]    
    102             
    103 if __name__== "__main__":  
    104     # 导入数据  
    105     dataset, labels = createDataSet()  
    106     inX = [0.1, 0.1]  
    107     # 简单分类  
    108     className = classify0(inX, dataset, labels, 3)  
    109     print 'the class of test sample is %s' %className    
  • 相关阅读:
    luogu 1865 数论 线性素数筛法
    洛谷 2921 记忆化搜索 tarjan 基环外向树
    洛谷 1052 dp 状态压缩
    洛谷 1156 dp
    洛谷 1063 dp 区间dp
    洛谷 2409 dp 月赛题目
    洛谷1199 简单博弈 贪心
    洛谷1417 烹调方案 dp 贪心
    洛谷1387 二维dp 不是特别简略的题解 智商题
    2016 10 28考试 dp 乱搞 树状数组
  • 原文地址:https://www.cnblogs.com/oceanL/p/6635125.html
Copyright © 2011-2022 走看看