zoukankan      html  css  js  c++  java
  • 机器学习实战-KNN(K-近邻算法)详解

    KNN(K-近邻算法):算法本身是一个有监督学习的算法,故训练数据是有标签的,算法的原理是计算测试数据距离训练数据的距离(一般是欧式距离),将计算出的距离进行从小到大的排序,取前K个距离对应的训练数据,计算这K个数据中不同标签所占比例,比例最高的标签即为测试数据所属于的类

    以下为python实现K-近邻算法详解:

      1 #coding:UTF-8
      2 
      3 from numpy import *
      4 import operator
      5 
      6 def createDataSet():
      7     group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
      8     labels = ['A','A','B','B']
      9     return group,labels
     10 
     11 
     12 #k-近邻算法
     13 #inX 用于分类的输入向量
     14 #dataSet 训练样本集
     15 #labels  标签向量
     16 # 选择最近邻居的数目
     17 
     18 def classify0(inX,dataSet,labels,k):
     19     #shape 返回数组的行列数  shape[0]为数组的行数
     20     dataSetSize = dataSet.shape[0]
     21 
     22     #tile(x,(dataSetSize,1)) tile方法是将数组x 在列方向上复制1行,行方向上复制dateSetSize次
     23     # >>> import numpy
     24     # >>> numpy.tile([0,0],5) 在列方向上重复[0,0] 5次,行默认1次
     25     # array([0,0,0,0,0,0,0,0,0,0])
     26     # >>> numpy.tile([0,0],(1,1)) 复制[0,0] 在列方向上1次  在行方向上1次
     27     # array([[0,0]])
     28     # >>> numpy.tile([0,0],(2,1)) 在列方向上复制1次 行方向上复制2次
     29     # array([[0,0],
     30     #     [0,0])
     31     # >>> numpy.tile([0,0],(3,1)) 列1次 行3次
     32     # array([[0, 0],
     33     #        [0, 0],
     34     #        [0, 0]])
     35     # >>> numpy.tile([0,0],(1,3)) 列3次 行1次
     36     # array([[0,0,0,0,0,0]])
     37     # >>> numpy.tile([0,0],(2,3)) 列3次 行2次
     38     # array([[0,0,0,0,0,0]
     39     #       [0,0,0,0,0,0]])
     40 
     41     
     42     #复制后计算差值
     43     diffMat = tile(inX, (dataSetSize,1)) - dataSet
     44 
     45     #将diffMat数组中的每个元素进行平方
     46     sqDiffMat = diffMat ** 2
     47 
     48     #将数组sqDiffMat按行相加 axis=1 表示按照横轴 sum表示累加
     49     sqDistances = sqDiffMat.sum(axis = 1)
     50     
     51     #将sqDistances开根号
     52     distances = sqDistances ** 0.5
     53 
     54 
     55     #以上部分即为欧式距离公式 计算两个向量点之间的距离
     56     # (1)二维平面上两点a(x1,y1)与b(x2,y2)间的欧氏距离:x1 - x2的平方加上 y1-y2的平方 然后开根号 
     57 
     58 
     59 
     60     # 按照升序进行快速排序,返回的是原数组的下标。  
     61     # 比如,x = [30, 10, 20, 40]  
     62     # 升序排序后应该是[10,20,30,40],他们的原下标是[1,2,0,3]  
     63     # 那么,numpy.argsort(x) = [1, 2, 0, 3] 
     64     sortedDistIndicies = distances.argsort()
     65     
     66     
     67     #存放最终的分类结果及相应的结果投票数的字典
     68     classCount = {}
     69 
     70     #统计前k个最近的样本所属类别包含的样本个数
     71     for i in range(k):
     72         #sortedDistIndicies[i] 第i个样本的下标
     73         #voteIlabel = labels[sortedDistIndicies[i]] 是对应labels的结果(“A” or “B”) 
     74         voteIlabel = labels[sortedDistIndicies[i]]
     75         
     76         #classCount.get(voteIlabel,0) 返回voteIlabel的值 不存在则返回0
     77         #然后将对应结果加1
     78         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
     79     #sorted 排序函数
     80     #(1)第一个参数是要排序的list或者lables   classCount.iteritems()表示迭代输出字典的键值对这里应该是类似{"A":3,"B:2"} 这样的字典
     81     #(2)key为函数,指定取待排序元素的哪一项进行排序
     82     #operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为一些序号(即需要获取的数据在对象中的序号)
     83     #>>> import operator
     85     #>>> a = [1,2,3]
     86     #>>> b = operator.itemgetter(1) 定义函数b,获取对象的第1个域的值
     87     #>>> b(a)
     88     #2
     89     #>>> b = operator.itemgetter(1,0) 定义函数b,获取对象的第一个域和第0个值
     90     #>>> b(a)
     91     #(2, 1)
     92     #>>> b = operator.itemgetter(1,1)
     93     #>>> b(a)
     94     #(2, 2)
     95     #这里是{"A":3,"B:2"}的对象 所以应该按照values排序即随想的第一个域即operator.itemgetter(1)
     96     #要注意,operator.itemgetter函数获取的不是值,而是定义了一个函数,通过该函数作用到对象上才能获取值
     97     #(3)reverse 排序规则 默认False 升序排列  True 降序排列
     98     #(4)返回值:是一个经过排序的可迭代类型,与iterable一样
     99     sortedClassCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True)
    100     #返回数量最多的类
    101     return sortedClassCount[0][0]    
    102             
    103 if __name__== "__main__":  
    104     # 导入数据  
    105     dataset, labels = createDataSet()  
    106     inX = [0.1, 0.1]  
    107     # 简单分类  
    108     className = classify0(inX, dataset, labels, 3)  
    109     print 'the class of test sample is %s' %className    
  • 相关阅读:
    第一节,Django+Xadmin打造上线标准的在线教育平台—创建用户app,在models.py文件生成3张表,用户表、验证码表、轮播图表
    Tensorflow 错误:Unknown command line flag 'f'
    Python 多线程总结
    Git 强制拉取覆盖本地所有文件
    Hive常用函数 傻瓜学习笔记 附完整示例
    Linux 删除指定大小(范围)的文件
    Python 操作 HBase —— Trift Trift2 Happybase 安装使用
    梯度消失 梯度爆炸 梯度偏置 梯度饱和 梯度死亡 文献收藏
    Embedding 文献收藏
    深度学习在CTR预估中的应用 文献收藏
  • 原文地址:https://www.cnblogs.com/oceanL/p/6635125.html
Copyright © 2011-2022 走看看