zoukankan      html  css  js  c++  java
  • 机器学习-简单介绍入门算法kNN

    基于Peter Harrington所著《Machine Learning in Action》

    kNN,即k-NearestNeighbor算法,是一种最简单的分类算法,拿这个当机器学习、数据挖掘的入门实例是非常合适的。

    简单的解释一下kNN的原理和目的:

    假设有一种数据,每一条有两个特征值,这些数据总共有两大类,例如:

    [ [1 , 1.1] , [ 1 , 1 ] , [0 , 0 ] , [0 , 0.1] ] 这四个数据(训练数据),种类分别为[ 'A' , 'A' , 'B' ,'B' ]。

    现在给出一条数据X=[1.1 , 1.1],需要判断这条数据属于A还是B,这时候就可以用kNN来判断。当然现实中每个数据可能有很多个特征,总共也有很多分类,这里以最简单的方式来举例。

    原理也非常简单,将上述训练数据放到坐标轴中,然后计算X到每个训练数据的距离,从近到远做个排序,选取其中的前N条,判断其中是属于A类的数据多还是B类的多,如果属于A类的多,那可以认为X属于A;反之亦然。

    下面就用具体的代码来演示一下上面陈述的算法。(基于py,建议直接装anaconda,一劳永逸)

    建立一个py文件,名称随意。

    先是创建训练数据,这里用py数组代替,实际可能是一堆文本或其他格式

    def createDataSet():
        group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
        lables=['A','A','B','B']
        return group,lables

    group是训练数据,lables是group的分类数据。

    kNN算法

     1 def classify0(inX,dataSet,labels,k):
     2     dataSetSize = dataSet.shape[0]
     3     diffMat=tile(inX,(dataSetSize,1))-dataSet
     4     sqDiffMat=diffMat**2
     5     sqDistances= sqDiffMat.sum(axis=1)
     6     distances = sqDistances**0.5
     7     sortedDistIndicies = distances.argsort()
     8     classCount={}
     9     for i in range(k):
    10         voteIlabel = labels[sortedDistIndicies[i]]
    11         classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
    12     sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    13     return  sortedClassCount[0][0]

    inX就是需要判断的数据X,dataSet就是上面构建出来的group,labels是group的分类,k是“从近到远做个排序,选取其中的前N条”中的N,这个会对结果的准确性有一定影响。

    这里用以下测试数据结合kNN算法进行讲解。

    group,labels = createDataSet()
    inX= [1.1,1.10]
    print(classify0(inX,group,labels,3))

    第2行是计算dataSet的总行数,此时为4。下面3,4,5,6行就是计算inX到每个训练数据的距离的,用的是欧式距离公式0ρ = sqrt( (x1-x2)^2+(y1-y2)^2 ),高中都学过。只是用矩阵和python的形式写出来可能一时不好看明白。

    3 tile(inX,(dataSetSize,1)) 构建出一个每一行都是inX,有dateSetSize行的矩阵,具体数据如下:

    [[ 1.1  ,1.1]
    [ 1.1  ,1.1]
    [ 1.1  ,1.1]
    [ 1.1  ,1.1]]

    再用这个矩阵减去dataSet,则会得到inX和每个训练数据的x,y坐标上的差值。最终的diffMat如下,这就是inX对每个训练数据的 x1-x2,y1-y2:

    [[ 0.1  ,0. ]
    [ 0.1  ,0.1]
    [ 1.1  ,1.1]
    [ 1.1  ,1. ]]

    4就是对矩阵做个平方,得到结果如下,即(x1-x2)^2,(y1-y2)^2:

    [[ 0.01  ,0. ]
    [ 0.01  ,0.01]
    [ 1.21  ,1.21]
    [ 1.21  ,1. ]]

    5就是把矩阵横向相加,得到结果就是(x1-x2)^2+(y1-y2)^2:

    [ 0.01  ,0.02 , 2.42,  2.21]

    6就是对5得到的(x1-x2)^2+(y1-y2)^2进行开根号,得到inX到每个训练数据的距离,结果如下:

    [ 0.1  ,       0.14142136 , 1.55563492 , 1.48660687]

    7是对6做个排序,9-11就是选出和inX距离最近的前k个点,统计这k个点中有几个属于A,有几个属于B。在本例中,得到的sortedClassCount为:

     [('A', 2), ('B', 1)]

    也就是和inX最近的三个点有两个属于A,一个B。这是,就可以认为inX是属于A类的了。

  • 相关阅读:
    Csharp: create Transparent Images in winform
    HTML5:Subway Map Visualization jQuery Plugin(示例畫深圳地鐵線路圖)
    sql 语句 查询 sql server 主键!
    面向对象学习
    聚类算法学习笔记(一)——基础
    oracle 会话以及处理数
    java.util.Calendar常量字段值
    java连接sql时候,获取表格各列属性
    Oracle 动态SQL返回单条结果和结果集
    Oracle数据库数据字典学习
  • 原文地址:https://www.cnblogs.com/csonezp/p/8567990.html
Copyright © 2011-2022 走看看