zoukankan      html  css  js  c++  java
  • kNN算法笔记

    kNN算法笔记

    标签(空格分隔): 机器学习


    kNN是什么

    kNN算法是k-NearestNeighbor算法,也就是k邻近算法。是监督学习的一种。所谓监督学习就是有训练数据,训练数据有label标好(也就是分类分好的)。kNN的思路是,对于需要测试的数据,把它和训练集中的每个数据都进行距离计算,距离最近的前k个结果中,所对应的label出现次数最多的,就是这个测试数据所属的label(类别)。

    kNN一般步骤

    按照《machine learning in action》一书中的通用步骤走一遍:

    1. 计算已知类别数据集中的点与当前点之间的距离
    2. 按照距离递增次序排序
    3. 选取与当前点距离最小的k个点
    4. 确定前k个点所在类别的出现频率
    5. 返回前k个点出现频率最高的类别作为当前点的预测分类

    不过按偏向程序编写的角度来说,是:

    1. 准备训练数据集,包括向量数据A和对应的标签(也就是分类)
    2. 将待测数据构造成和训练集同样数量的一个矩阵,也就是自重复,得到B
    3. 计算A和B中每一个向量之间的距离:每一维对应相减,然后求平方和,在开根号,得到距离序列D。
    4. 从距离序列D中选取排序后的前k个,得到序列E
    5. 逐一扫描序列E,把每个元素对应的label作为key,label出现的次数作为value,边扫描边生成这样的(key,value)构成的字典F
    6. 将字典F按照value排序,排序后的第一个元素value最大,表示对应的key也就是label(类别)出现次数最多,作为测试数据的预测归类。

    kNN代码编写

    假设有四个点P1(1,1.1),P2(1,1),P3(0,0),P4(0,0.1),分别属于A、A类和B、B类。现在想测试Q(0,0)属于哪个类别。画出坐标图,显然是B类。使用kNN编写程序计算,代码如下:

    #!/usr/bin/python
    #coding:gbk
    
    from numpy import *
    import operator
    
    def createDataSet():
        group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
        labels = ['A','A','B','B']
        return group, labels
    
    def classify0(inX, dataSet, labels, k):
        """
        @param inX: 用于分类的输入向量
        @param dataSet:训练样本集
        @param labels:标签向量    len(labels)==len(dataSet)
        @param k:用于选择最近邻居的数目s
        """
        dataSetSize = dataSet.shape[0]      #矩阵的行数
        diffMat = tile(inX, (dataSetSize, 1)) - dataSet
        #tile(inX, (dataSetSize, 1)):构造一个矩阵,行数和dataSet相同,每行都是1个inX构成
        #diffMat就是计算:构造出的矩阵与训练集在每个维度上求差值("距离")  ==> 差值矩阵
        
        sqDiffMat = diffMat ** 2 #矩阵中每个元素都求平方
        
        sqDistances = sqDiffMat.sum(axis=1)  #计算每一行的sigma   axis=0:每一列   axis=1:每一行
        distances = sqDistances ** 0.5
        sortedDistIndicies = distances.argsort()  
        #np.argsort()得到的是排序后的数据原来位置的下标。
        #而distances本身并没有被改变掉
        
        classCount={}
        for i in range(k):
            voteIlabel = labels[sortedDistIndicies[i]]
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
            #D.get(k[,d]) -> D[k] if k in D, else d.  d defaults to None.
            #即:classCount.get(voteIlabel,0)在存在voteIlabel这个key的时候,得到classCount[voteIlabel)
            #当不存在这个key的时候,得到0。 也就是一个累加器
            
        sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
        #itermitems():字典类型的数据所拥有的迭代器
        #operator.itemgetter(n):是定义了一个函数,通过该函数作用到对象上才能获取值。  
        #这里itemgetter(1)取得第二列的值,也就是classCount这个字典的value。按照这个value来排序。
        #reverse=True:列表按照降序排列。
        
        return sortedClassCount[0][0]
    
    
    
    if __name__ == '__main__':
        dataSet, labels = createDataSet()
        inX=[0,0]
        k=3
        result=classify0(inX, dataSet, labels, k)
        print result
        
    

    简单实践:改进约会网站的配对效果

    还是书上的例子,不过也算是简单的实际应用了:每一个被邀请约会的人有四个参数(a,b,c,label),前三个是三个维度的数据,最后一个参数是(a,b,c)组合得到的类别。现在有若干(a,b,c,label)数据,需要用kNN算法,对于给定的(a',b',c'),计算label'。

    数据分布

    (a,b,c)数据分布如图所示,每一种颜色代表一个类别。
    (a,b,c)数据分布如图所示

    数据归一化

    比如(a,b,c,label)序列中,a、b、c三者权重相同,而b这一维度数据的变化范围远远大于a、c各自的变化范围,此时算出的向量距离中b的权重过大。一个解决办法是把a、b、c三个维度的数据都映射到同一个变化范围,比如[0,1]

    def autoNorm(dataSet):
        """
        归一化:将一系列的数值转化为[0,1]之间的数据
        """
        minVals = dataSet.min(0)
        maxVals = dataSet.max(0)
        #min(0):将dataSet的列向量逐一求min,再构造为list。也就是:每一列分别求最小值,最后构成一个list
        #min(1):对行向量进行处理
        #类似于scheme中的(flatmap proc seq) = (flatmap min dataSet),也就是map+accumulate的过程
        
        ranges = maxVals - minVals
        normDataSet = zeros((shape(dataSet)))
        m = dataSet.shape[0]   #数据行数,也就是数据条目数
        normDataSet = dataSet - tile(minVals, (m, 1))
        normDataSet = normDataSet / tile(ranges, (m, 1))
        
        return normDataSet, ranges, minVals
    

    小插曲

    autoNorm函数计算的到的normDataSet正确运行结果如下(和书上不一样,感觉书上结果是错的):

    [[ 0.44832535  0.39805139  0.56233353]
     [ 0.15873259  0.34195467  0.98724416]
     [ 0.28542943  0.06892523  0.47449629]
     ..., 
     [ 0.29115949  0.50910294  0.51079493]
     [ 0.52711097  0.43665451  0.4290048 ]
     [ 0.47940793  0.3768091   0.78571804]]
    

    分类结果

    好了,终于可以调用各函数,实现最终的计算了。这里采取前10%的数据作为测试数据,后90%的数据作为训练数据。代码如下:

    def datingClassTest():
        """
        分类器针对约会网站的测试代码
        """
        hoRatio = 0.10
        datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
        normMat, ranges, minVals = autoNorm(datingDataMat)
        m = normMat.shape[0]
        numTestVecs = int(m*hoRatio)
        errorCount = 0.0
        for i in range(numTestVecs):
            classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m,:], datingLabels[numTestVecs:m], 3)
            print "分类器计算出的分类为: %d, 实际分类为 is: %d"%(classifierResult, datingLabels[i])
            
            if classifierResult!=datingLabels[i]:
                errorCount += 1.0
        print "总体错误率为: %f"%(errorCount/float(numTestVecs))
        
    if __name__ == '__main__':
        datingClassTest()
    
    

    运行结果:

    分类器计算出的分类为: 2, 实际分类为 is: 2
    分类器计算出的分类为: 1, 实际分类为 is: 1
    分类器计算出的分类为: 3, 实际分类为 is: 3
    分类器计算出的分类为: 3, 实际分类为 is: 3
    分类器计算出的分类为: 2, 实际分类为 is: 2
    分类器计算出的分类为: 1, 实际分类为 is: 1
    分类器计算出的分类为: 3, 实际分类为 is: 1
    总体错误率为: 0.050000
    

    总结

    通过kNN算法的学习知道,算法实际步骤如下:

    1. 准备数据,包括数据本身和label(类别)标记
    2. 数据整形:读入数据并调整为向量格式
    3. 切分数据块。比如10%为测试数据,90%为训练数据。
    4. 数据权重调整。如果每个维度权重相同,则可以使用归一化方法
    5. 计算测试数据和训练数据之间的距离
    6. 选取距离最小的k个结果,遍历这些结果,统计对应的label频数
    7. 将上一步得到的结果按照频数排序,排序后第一个结果的label就是计算出的label
    8. 每一个测试数据在计算时,可同时统计错误率

    另一个实践:手写识别过程

    这里检测各位数字,即0~9。手写识别过程,经过图像预处理,得到32*32的矩阵表示的数字,其中数字形状区域用1表示,非数字的空白区域用0表示。训练数据和测试数据分属两个文件夹下,每个文件都是只有一个数字的txt文件,文件命名有规律:文件对应的数字+不重复的序号。

    首先是将32*32的二维数组转换为一维的数组,也就是1024维的一个行向量。同时从文件名中取出当前向量对应的类别。这样就完成了训练集的操作。

    对于测试集,用同样的方式先构造1024维的向量,然后每一个都通过kNN的分类器函数classify0()来计算其label(分类),并统计误差。

    def img2vector(filename):
        vector=zeros(1024)
        fr = open(filename)
        for i in xrange(32):
            lineStr = fr.readline()
            for j in xrange(32):
                vector[32*i+j] = int(lineStr[j])
        return vector
    
    def simple_img2vector_test():
        testVector = img2vector('testDigits/0_13.txt')
        print testVector[0:31]
        
    def handwritingClassTest():
        hwLabels=[]
        trainingFileList=listdir('trainingDigits')
        m=len(trainingFileList)
        trainingMat=zeros((m, 1024))
        for i in xrange(m):
            fileNameStr=trainingFileList[i]
            fileStr=fileNameStr.split('.')[0]
            classNumStr = int(fileStr.split('_')[0])
            hwLabels.append(classNumStr)
            trainingMat[i,:]=img2vector('trainingDigits/%s'%fileNameStr)
        testFileList=listdir('testDigits')
        errorCount=0.0
        mTest=len(testFileList)
        for i in xrange(mTest):
            fileNameStr=testFileList[i]
            fileStr=fileNameStr.split('.')[0]
            classNumStr=int(fileStr.split('_')[0])
            vectorUnderTest=img2vector('testDigits/%s'%fileNameStr)
            classifierResult=classify0(vectorUnderTest, trainingMat, hwLabels, 3)
            print "分类器计算出的类别为: %d, 实际类别为: %d" % (classifierResult, classNumStr)
            if classifierResult!=classNumStr:
                errorCount += 1.0
        print "
    计算错误的总次数为: %d"%errorCount
        print "
    总错误率为: %f"%(errorCount/float(mTest))
            
        
    if __name__ == '__main__':
        handwritingClassTest()
    

    运行结果:

    分类器计算出的类别为: 9, 实际类别为: 9
    分类器计算出的类别为: 9, 实际类别为: 9
    分类器计算出的类别为: 9, 实际类别为: 9
    分类器计算出的类别为: 9, 实际类别为: 9
    分类器计算出的类别为: 9, 实际类别为: 9
    分类器计算出的类别为: 9, 实际类别为: 9
    分类器计算出的类别为: 9, 实际类别为: 9
    分类器计算出的类别为: 9, 实际类别为: 9
    
    计算错误的总次数为: 11
    
    总错误率为: 0.011628
    
    Greatness is never a given, it must be earned.
  • 相关阅读:
    Oracle decode函数
    Flink笔记
    httpclient之put 方法(参数为json类型)
    XMLHTTPRequest的理解 及 SpringMvc请求和响应xml数据
    SQL获取本周,上周,本月,上月第一天和最后一天 注:本周从周一到周天
    Other
    Sql根据起止日期生成时间列表
    sql 在not in 子查询有null值情况下经常出现的陷阱
    sql 判断一个表的数据不在另一个表中
    查看系统触发器
  • 原文地址:https://www.cnblogs.com/zjutzz/p/4596124.html
Copyright © 2011-2022 走看看