zoukankan      html  css  js  c++  java
  • 《机器学习实战》中的程序清单2-1 k近邻算法(kNN)classify0都做了什么

    from numpy import *
    import operator
    import matplotlib
    import matplotlib.pyplot as plt
    from imp import *
    #from os import *
    import os
    
    reload(operator)
    
    def start():
        group,labels = createDataSet()
        testSample = [5,7]
        print("测试样本:" ,end="")
        print(testSample)
    
        return classify0(testSample, group, labels, 4)
    
    def createDataSet():
    
        group = array([[1,2],[2,3],[1,1],[4,5]]) #此处随意定义,表示一个已知的已分类的数据集
        labels = ['A','A','B','B']
    
        #例如
        #group = array([[1,2],[2,3],[1,1],[4,5],[5,7],[6,6]]) #此处随意定义,表示一个已知的已分类的数据集
        #labels = ['A','A','B','B','C','C']
    
        return group, labels 
    
    def classify0(inX, dataSet, labels, k):
        """
          inX 是输入的测试样本,是一个[x, y]样式的
          dataset 是训练样本集
          labels 是训练样本标签
          k 是top k最相近的
        """
    
        # 矩阵的shape是个tuple,如果直接调用dataSet.shape,会返回(4,2),即
        # 返回矩阵的(行数,列数),
        # 那么shape[0]获取数据集的行数,
        # 行数就是样本的数量
        # shape[1]返回数据集的列数
        dataSetSize = dataSet.shape[0]
    
        ###################说明代码########################
        #print("dataSet.shape[0]返回矩阵的行数:")
        #print(dataSetSize)
        #print("dataSet.shape[1]返回矩阵的列数:")
        #cols = dataSet.shape[1]
        #print(cols)
        #print(dataSet.shape)
        #print("dataSet.shape类型:")
        #print(type(dataSet.shape))
        ###################################################
    
        #此处Mat是Maxtrix的缩写,diffMat,即矩阵的差,结果也是矩阵
        #关于tile函数的说明,见http://www.cnblogs.com/Sabre/p/7976702.html
        #简单来说就是把inX(本例是[1,1])在“行”这个维度上,复制了dataSetSize次(本例dataSetSize==4),在“列”这个维度上,复制了1次
        #形成[[1,1],[1,1],[1,1],[1,1]]这样一个矩阵,以便与dataSet进行运算
        #之所以进行这样的运算,是因为要使用欧式距离公式求输入点与已存在各点的距离
    
        #这是第1步,求给出点[1,1]与已知4点的差,输出为矩阵
        diffMat = tile(inX,(dataSetSize,1)) - dataSet
        #print(tile(inX,(dataSetSize,1)))
    
        ###################说明代码########################
        #print("diffMat:" + str(diffMat))
        ###################################################
        
        #第2步,对矩阵进行平方,即,求差的平方
        sqDiffMat = diffMat ** 2
    
        ###################说明代码########################
        #print("sqDiffMat:" + str(sqDiffMat))
        #print("sqDiffMat",end="")
        #print(sqDiffMat[324])
        ###################################################
    
        #sum(axis=1)是将矩阵中每一行中的数值相加,如[[0 0] [1 1] [0 1] [9 9]]将得到[0,2,1,18],得到平方和
        #sum(axis=0)是将矩阵中每一列中的数值相加
        #第3步,求和
        sqDistances = sqDiffMat.sum(axis=1)
        #print("sqDistances:", end="")
        #print(sqDistances[875])
    
        ###################说明代码########################
        #print("sqDistances:" + str(sqDistances))
        ###################################################
        
        #第4步,将平方和进行开方,得到距离,输出为数组
        distances = sqDistances ** 0.5
    
    
        ###################说明代码########################
        #print("未知点到各个已知点的距离:",distances)
        ###################################################
    
        #argsort(),将数组中的元素的索引放在由小到大的位置上由小到大排序
        #如数组a = array([ 0 4 3 18]),b = a.argsort()之后,得到b是[0 2 1 3]这是a的索引数组,最小的在最前面,位置0,第二小的是索引为2的元素,即3,3在数组中的位置是2
        #第三小的是索引为1的,即4,4在数组中的索引位置是2,第四小的是索引为3的,即18
        #这样保证了原数组元素的位置不变,以便进行标签的匹配
        #print(distances[875])
        #print(distances[324])
        #print(distances[392])
        sortedDistIndicies = distances.argsort()
    
        ###################说明代码########################
        #print("索引位置:", sortedDistIndicies) #可得到前k个索引
        ###################################################
        
        #创建空字典
        classCount = {} 
        
        #k值是取前k个样本进行比较
        for i in range(k):
            #返回distances中索引为sortedDistIndicies[i]的值
            #此例中分别为:
            #sortedDistIndicies[0]==0,则labels[0]=='A',voteIlabel=='A'
            #sortedDistIndicies[1]==2,则labels[2]=='B',voteIlabel=='B'
            #sortedDistIndicies[2]==1,则labels[0]=='A',voteIlabel=='A'
            #sortedDistIndicies[3]==18,则labels[0]=='B',voteIlabel=='B'
    
            voteIlabel = labels[sortedDistIndicies[i]] 
            #print("中华人民共和国")
            ###################说明代码########################
    
            # print(voteIlabel)
            # print("标签" + str(i) + ":" + str(voteIlabel))
            ###################################################
    
            #dict.get(key, default=None),对于键 key 返回其对应的值,或者若 dict 中不含 key 则返回 default(注意, default的默认值为 None,此处设置为0)
            #第一次调用classCount.get时,classCount内还没有值
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
            
            ###################说明代码########################
            # print("第"+str(i+1)+"次访问,classCount[" + str(voteIlabel) + "]值为:" + str(classCount[voteIlabel]))
            # print("classCount的内容为:")
            # print(classCount)
            ###################################################
        
        # sorted(iterable[,cmp,[,key[,reverse=True]]])
        # 作用:Return a new sorted list from the items in iterable.
        #           第一个参数是一个iterable,返回值是一个对iterable中元素进行排序后的列表(list)。
        # 可选的参数有三个,cmp、key和reverse。
        # 1)cmp指定一个定制的比较函数,这个函数接收两个参数(iterable的元素),如果第一个参数小于第二个参数,返回一个负数;如果第一个参数等于第二个参数,返回零;如果第一个参数大于第二个参数,返回一个正数。默认值为None。
        # 2)key指定一个接收一个参数的函数,这个函数用于从每个元素中提取一个用于比较的关键字。默认值为None。
        #   从python2.4开始,list.sort()和sorted()函数增加了key参数来指定一个函数,此函数将在每个元素比较前被调用
        #   key参数的值为一个函数,此函数只有一个参数且返回一个值用来进行比较。这个技术是快速的,因为key指定的函数将准确地对每个元素调用。
        #   key=operator.itemgetter(0)或key=operator.itemgetter(1),决定以字典的键排序还是以字典的值排序
        #   0以键排序,1以值排序
        # 3)reverse是一个布尔值。如果设置为True,列表元素将被倒序排列。
        # operator.itemgetter(1)这个很难解释,用以下的例子一看就懂
        # a=[11,22,33]
        # b = operator.itemgetter(2)
        # b(a)
        # 输出:33
        # b = operator.itemgetter(2,0,1)
        # b(a)
        # 输出:(33,11,22)
        # operator.itemgetter函数返回的不是值,而是一个函数,通过该函数作用到对象上才能获取值
       # 在这里
    itemgetter(1)的作用是按照第二个元素的顺序对元组进行排序,也就是value的顺序,如果改成itemgetter(0),则根据Key值排序
        sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1), reverse=True)
    
        #print(sortedClassCount)
    
        #返回正序排序后最小的值,即“k个最小相邻”的值决定测试样本的类别
        print("最终结果,测试样本类别:" , end="")
        print(sortedClassCount)
        return sortedClassCount[0][0] 

    以下为输出结果,未必完全一致,请自行调试。

    输出结果:

    dataSet.shape[0]返回矩阵的行数:
    4
    dataSet.shape[1]返回矩阵的列数:
    2
    (4, 2)
    dataSet.shape类型:
    <class 'tuple'>
    diffMat:[[ 2 1]
    [ 1 0]
    [ 2 2]
    [-1 -2]]
    sqDiffMat:[[4 1]
    [1 0]
    [4 4]
    [1 4]]
    sqDistances:[5 1 8 5]
    未知点到各个已知点的距离: [ 2.23606798 1. 2.82842712 2.23606798]
    索引位置: [1 0 3 2]
    标签0:A
    第1次访问,classCount[A]值为:1
    classCount的内容为:
    {'A': 1}
    标签1:A
    第2次访问,classCount[A]值为:2
    classCount的内容为:
    {'A': 2}
    标签2:B
    第3次访问,classCount[B]值为:1
    classCount的内容为:
    {'A': 2, 'B': 1}
    标签3:B
    第4次访问,classCount[B]值为:2
    classCount的内容为:
    {'A': 2, 'B': 2}
    [('A', 2), ('B', 2)]
    最终结果,测试样本类别:A
    [Finished in 5.3s]
  • 相关阅读:
    PLSQL过程创建和调用
    约束定义及相关用法
    序列和索引
    控制用户访问
    ORACLE常用数据字典
    管理对象与数据字典
    Oracle enterprise linux系统的安装以及ORACLE12C的安装
    SUSE12的虚拟机安装以及ORACLE12C的安装
    PLSQL developer开发工具相关配置
    设计模式之六则并进
  • 原文地址:https://www.cnblogs.com/Sabre/p/8359017.html
Copyright © 2011-2022 走看看