zoukankan      html  css  js  c++  java
  • 机器学习实战-KNN(K-近邻算法)详解

    KNN(K-近邻算法):算法本身是一个有监督学习的算法,故训练数据是有标签的,算法的原理是计算测试数据距离训练数据的距离(一般是欧式距离),将计算出的距离进行从小到大的排序,取前K个距离对应的训练数据,计算这K个数据中不同标签所占比例,比例最高的标签即为测试数据所属于的类

    以下为python实现K-近邻算法详解:

      1 #coding:UTF-8
      2 
      3 from numpy import *
      4 import operator
      5 
      6 def createDataSet():
      7     group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
      8     labels = ['A','A','B','B']
      9     return group,labels
     10 
     11 
     12 #k-近邻算法
     13 #inX 用于分类的输入向量
     14 #dataSet 训练样本集
     15 #labels  标签向量
     16 # 选择最近邻居的数目
     17 
     18 def classify0(inX,dataSet,labels,k):
     19     #shape 返回数组的行列数  shape[0]为数组的行数
     20     dataSetSize = dataSet.shape[0]
     21 
     22     #tile(x,(dataSetSize,1)) tile方法是将数组x 在列方向上复制1行,行方向上复制dateSetSize次
     23     # >>> import numpy
     24     # >>> numpy.tile([0,0],5) 在列方向上重复[0,0] 5次,行默认1次
     25     # array([0,0,0,0,0,0,0,0,0,0])
     26     # >>> numpy.tile([0,0],(1,1)) 复制[0,0] 在列方向上1次  在行方向上1次
     27     # array([[0,0]])
     28     # >>> numpy.tile([0,0],(2,1)) 在列方向上复制1次 行方向上复制2次
     29     # array([[0,0],
     30     #     [0,0])
     31     # >>> numpy.tile([0,0],(3,1)) 列1次 行3次
     32     # array([[0, 0],
     33     #        [0, 0],
     34     #        [0, 0]])
     35     # >>> numpy.tile([0,0],(1,3)) 列3次 行1次
     36     # array([[0,0,0,0,0,0]])
     37     # >>> numpy.tile([0,0],(2,3)) 列3次 行2次
     38     # array([[0,0,0,0,0,0]
     39     #       [0,0,0,0,0,0]])
     40 
     41     
     42     #复制后计算差值
     43     diffMat = tile(inX, (dataSetSize,1)) - dataSet
     44 
     45     #将diffMat数组中的每个元素进行平方
     46     sqDiffMat = diffMat ** 2
     47 
     48     #将数组sqDiffMat按行相加 axis=1 表示按照横轴 sum表示累加
     49     sqDistances = sqDiffMat.sum(axis = 1)
     50     
     51     #将sqDistances开根号
     52     distances = sqDistances ** 0.5
     53 
     54 
     55     #以上部分即为欧式距离公式 计算两个向量点之间的距离
     56     # (1)二维平面上两点a(x1,y1)与b(x2,y2)间的欧氏距离:x1 - x2的平方加上 y1-y2的平方 然后开根号 
     57 
     58 
     59 
     60     # 按照升序进行快速排序,返回的是原数组的下标。  
     61     # 比如,x = [30, 10, 20, 40]  
     62     # 升序排序后应该是[10,20,30,40],他们的原下标是[1,2,0,3]  
     63     # 那么,numpy.argsort(x) = [1, 2, 0, 3] 
     64     sortedDistIndicies = distances.argsort()
     65     
     66     
     67     #存放最终的分类结果及相应的结果投票数的字典
     68     classCount = {}
     69 
     70     #统计前k个最近的样本所属类别包含的样本个数
     71     for i in range(k):
     72         #sortedDistIndicies[i] 第i个样本的下标
     73         #voteIlabel = labels[sortedDistIndicies[i]] 是对应labels的结果(“A” or “B”) 
     74         voteIlabel = labels[sortedDistIndicies[i]]
     75         
     76         #classCount.get(voteIlabel,0) 返回voteIlabel的值 不存在则返回0
     77         #然后将对应结果加1
     78         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
     79     #sorted 排序函数
     80     #(1)第一个参数是要排序的list或者lables   classCount.iteritems()表示迭代输出字典的键值对这里应该是类似{"A":3,"B:2"} 这样的字典
     81     #(2)key为函数,指定取待排序元素的哪一项进行排序
     82     #operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为一些序号(即需要获取的数据在对象中的序号)
     83     #>>> import operator
     85     #>>> a = [1,2,3]
     86     #>>> b = operator.itemgetter(1) 定义函数b,获取对象的第1个域的值
     87     #>>> b(a)
     88     #2
     89     #>>> b = operator.itemgetter(1,0) 定义函数b,获取对象的第一个域和第0个值
     90     #>>> b(a)
     91     #(2, 1)
     92     #>>> b = operator.itemgetter(1,1)
     93     #>>> b(a)
     94     #(2, 2)
     95     #这里是{"A":3,"B:2"}的对象 所以应该按照values排序即随想的第一个域即operator.itemgetter(1)
     96     #要注意,operator.itemgetter函数获取的不是值,而是定义了一个函数,通过该函数作用到对象上才能获取值
     97     #(3)reverse 排序规则 默认False 升序排列  True 降序排列
     98     #(4)返回值:是一个经过排序的可迭代类型,与iterable一样
     99     sortedClassCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True)
    100     #返回数量最多的类
    101     return sortedClassCount[0][0]    
    102             
    103 if __name__== "__main__":  
    104     # 导入数据  
    105     dataset, labels = createDataSet()  
    106     inX = [0.1, 0.1]  
    107     # 简单分类  
    108     className = classify0(inX, dataset, labels, 3)  
    109     print 'the class of test sample is %s' %className    
  • 相关阅读:
    HUD --- 3635
    leetcode380- Insert Delete GetRandom O(1)- medium
    leetcode68- Text Justification- hard
    leetcode698- Partition to K Equal Sum Subsets- medium
    leetcode671- Second Minimum Node In a Binary Tree- easy
    leetcode647- Palindromic Substrings- medium
    leetcode633- Sum of Square Numbers- easy
    leetcode605- Can Place Flowers- easy
    leetcode515- Find Largest Value in Each Tree Row- medium
    leetcode464- Can I Win- medium
  • 原文地址:https://www.cnblogs.com/oceanL/p/6635125.html
Copyright © 2011-2022 走看看