zoukankan      html  css  js  c++  java
  • 机器学习 | 算法笔记- k近邻(KNN)

    前言

    本系列为机器学习算法的总结和归纳,目的为了清晰阐述算法原理,同时附带上手代码实例,便于理解。

    目录

      决策树
      组合算法(Ensemble Method)
      K-Means
      机器学习算法总结
     
    本章为k近邻算法,内容包括模型介绍及代码实现(包括自主实现和sklearn案例)。

    一、算法简介

    1.1 基本概念

    k近邻法(k-nearest neighbor, k-NN)是1967年由Cover T和Hart P提出的一种基本分类与回归方法。
    基本概念如下:存在一个样本数据集合,所有特征属性已知,并且样本集中每个对象都已知所属分类。对不知道分类的待测对象,将待测对象的每个特征属性与样本集中数据对应的特征属性进行比较,然后算法提取样本最相似对象(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的对象数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。最后根据这k个数据的特征和属性,判断待测数据的分类

    1.2 K近邻的三个基本要素

      1) k值的选取。在应用中,k值一般选择一个比较小的值,一般选用交叉验证来取最优的k值
      2)距离度量。Lp距离:误差绝对值p次方求和再求p次根。欧式距离:p=2的Lp距离。曼哈顿距离:p=1的Lp距离。p为无穷大时,Lp距离为各个维度上距离的最大值
      3)分类决策规则。也就是如何根据k个最近邻决定待测对象的分类。k最近邻的分类决策规则一般选用多数表决

    1.3 KNN基本执行步骤

      1)计算待测对象和训练集中每个样本点的欧式距离
      2)对上面的所有距离值排序
      3)选出k个最小距离的样本作为“选民”
      4)根据“选民”预测待测样本的分类或值

    1.4 KNN特点

      1)原理简单
      2)保存模型需要保存所有样本集
      3)训练过程很快,预测速度很慢
      · 优点:
      简单好用,容易理解,精度高,理论成熟,既可以用来做分类也可以用来做回归;
      可用于非线性分类;
      可用于数值型数据和离散型数据(既可以用来估值,又可以用来分类)
      训练时间复杂度为O(n);无数据输入假定;
      对异常值不敏感。
      准确度高,对数据没有假设,对outlier不敏感;
      · 缺点:
      计算复杂性高;空间复杂性高;需要大量的内存
      样本不平衡问题(即有些类别的样本数量很多,而其它样本的数量很少);
      一般数值很大的时候不用这个,计算量太大。但是单个样本又不能太少,否则容易发生误分。
      最大的缺点是无法给出数据的内在含义。
     
    需要思考的问题:
    样本属性如何选择?如何计算两个对象间距离?当样本各属性的类型和尺度不同时如何处理?各属性不同重要程度如何处理?模型的好坏如何评估?

    二、代码实现

    K近邻算法的一般流程:收集数据- 准备数据- 分析数据- 测试算法- 使用算法

    2.1 python3代码实现

    2.1.1
    首先以电影分类为例,了解kNN工作流程。主要包括创建数据,迭代计算两点公式。代码如下
    # -*- coding: UTF-8 -*-
    import numpy as np
    import operator
    import collections
    
    """
    函数说明:创建数据集
    
    Parameters:
        无
    Returns:
        group - 数据集
        labels - 分类标签
    
    """
    def createDataSet():
        #四组二维特征
        group = np.array([[1,101],[5,89],[108,5],[115,8]])
        #四组特征的标签
        labels = ['爱情片','爱情片','动作片','动作片']
        return group, labels
    
    """
    函数说明:kNN算法,分类器
    
    Parameters:
        inX - 用于分类的数据(测试集)
        dataSet - 用于训练的数据(训练集)
        labes - 分类标签
        k - kNN算法参数,选择距离最小的k个点
    Returns:
        sortedClassCount[0][0] - 分类结果
    
    """
    def classify0(inx, dataset, labels, k):
        # 计算距离
        dist = np.sum((inx - dataset)**2, axis=1)**0.5
        # k个最近的标签
        k_labels = [labels[index] for index in dist.argsort()[0 : k]]
        # 出现次数最多的标签即为最终类别
        label = collections.Counter(k_labels).most_common(1)[0][0]
        return label
    
    if __name__ == '__main__':
        #创建数据集
        group, labels = createDataSet()
        #测试集
        test = [101,20]
        #kNN分类
        test_class = classify0(test, group, labels, 3)
        #打印分类结果
        print(test_class)
    View Code
    2.1.2
    以K近邻算法实现约会网站配对效果判定。
      1)下载数据集 datingTestSet.txt
      2)准备数据:数据解析
          将数据分为特征矩阵和对应的分类标签矩阵。
    # -*- coding: UTF-8 -*-
    import numpy as np
    """
    函数说明:打开并解析文件,对数据进行分类:1代表不喜欢,2代表魅力一般,3代表极具魅力
    
    Parameters:
        filename - 文件名
    Returns:
        returnMat - 特征矩阵
        classLabelVector - 分类Label向量
    
    """
    def file2matrix(filename):
        #打开文件
        fr = open(filename)
        #读取文件所有内容
        arrayOLines = fr.readlines()
        #得到文件行数
        numberOfLines = len(arrayOLines)
        #返回的NumPy矩阵,解析完成的数据:numberOfLines行,3列
        returnMat = np.zeros((numberOfLines,3))
        #返回的分类标签向量
        classLabelVector = []
        #行的索引值
        index = 0
        for line in arrayOLines:
            #s.strip(rm),当rm空时,默认删除空白符(包括'
    ','
    ','	',' ')
            line = line.strip()
            #使用s.split(str="",num=string,cout(str))将字符串根据'	'分隔符进行切片。
            listFromLine = line.split('	')
            #将数据前三列提取出来,存放到returnMat的NumPy矩阵中,也就是特征矩阵
            returnMat[index,:] = listFromLine[0:3]
            #根据文本中标记的喜欢的程度进行分类,1代表不喜欢,2代表魅力一般,3代表极具魅力
            if listFromLine[-1] == 'didntLike':
                classLabelVector.append(1)
            elif listFromLine[-1] == 'smallDoses':
                classLabelVector.append(2)
            elif listFromLine[-1] == 'largeDoses':
                classLabelVector.append(3)
            index += 1
        return returnMat, classLabelVector
    
    """
    函数说明:main函数
    Parameters:
        无
    Returns:
        无
    """
    if __name__ == '__main__':
        #打开的文件名
        filename = "datingTestSet.txt"
        #打开并处理数据
        datingDataMat, datingLabels = file2matrix(filename)
        print(datingDataMat)
        print(datingLabels)
    View Code
      3)分析数据:数据可视化
      直观的发现数据的规律
    """
    函数说明:可视化数据
    
    Parameters:
        datingDataMat - 特征矩阵
        datingLabels - 分类Label
    Returns:
        无
    
    """
    def showdatas(datingDataMat, datingLabels):
        #设置汉字格式
        font = FontProperties(fname=r"c:windowsfontssimsun.ttc", size=14)
        #将fig画布分隔成1行1列,不共享x轴和y轴,fig画布的大小为(13,8)
        #当nrow=2,nclos=2时,代表fig画布被分为四个区域,axs[0][0]表示第一行第一个区域
        fig, axs = plt.subplots(nrows=2, ncols=2,sharex=False, sharey=False, figsize=(13,8))
    
        numberOfLabels = len(datingLabels)
        LabelsColors = []
        for i in datingLabels:
            if i == 1:
                LabelsColors.append('black')
            if i == 2:
                LabelsColors.append('orange')
            if i == 3:
                LabelsColors.append('red')
        #画出散点图,以datingDataMat矩阵的第一(飞行常客例程)、第二列(玩游戏)数据画散点数据,散点大小为15,透明度为0.5
        axs[0][0].scatter(x=datingDataMat[:,0], y=datingDataMat[:,1], color=LabelsColors,s=15, alpha=.5)
        #设置标题,x轴label,y轴label
        axs0_title_text = axs[0][0].set_title(u'每年获得的飞行常客里程数与玩视频游戏所消耗时间占比',FontProperties=font)
        axs0_xlabel_text = axs[0][0].set_xlabel(u'每年获得的飞行常客里程数',FontProperties=font)
        axs0_ylabel_text = axs[0][0].set_ylabel(u'玩视频游戏所消耗时间占',FontProperties=font)
        plt.setp(axs0_title_text, size=9, weight='bold', color='red') 
        plt.setp(axs0_xlabel_text, size=7, weight='bold', color='black') 
        plt.setp(axs0_ylabel_text, size=7, weight='bold', color='black')
    
        #画出散点图,以datingDataMat矩阵的第一(飞行常客例程)、第三列(冰激凌)数据画散点数据,散点大小为15,透明度为0.5
        axs[0][1].scatter(x=datingDataMat[:,0], y=datingDataMat[:,2], color=LabelsColors,s=15, alpha=.5)
        #设置标题,x轴label,y轴label
        axs1_title_text = axs[0][1].set_title(u'每年获得的飞行常客里程数与每周消费的冰激淋公升数',FontProperties=font)
        axs1_xlabel_text = axs[0][1].set_xlabel(u'每年获得的飞行常客里程数',FontProperties=font)
        axs1_ylabel_text = axs[0][1].set_ylabel(u'每周消费的冰激淋公升数',FontProperties=font)
        plt.setp(axs1_title_text, size=9, weight='bold', color='red') 
        plt.setp(axs1_xlabel_text, size=7, weight='bold', color='black') 
        plt.setp(axs1_ylabel_text, size=7, weight='bold', color='black')
    
        #画出散点图,以datingDataMat矩阵的第二(玩游戏)、第三列(冰激凌)数据画散点数据,散点大小为15,透明度为0.5
        axs[1][0].scatter(x=datingDataMat[:,1], y=datingDataMat[:,2], color=LabelsColors,s=15, alpha=.5)
        #设置标题,x轴label,y轴label
        axs2_title_text = axs[1][0].set_title(u'玩视频游戏所消耗时间占比与每周消费的冰激淋公升数',FontProperties=font)
        axs2_xlabel_text = axs[1][0].set_xlabel(u'玩视频游戏所消耗时间占比',FontProperties=font)
        axs2_ylabel_text = axs[1][0].set_ylabel(u'每周消费的冰激淋公升数',FontProperties=font)
        plt.setp(axs2_title_text, size=9, weight='bold', color='red') 
        plt.setp(axs2_xlabel_text, size=7, weight='bold', color='black') 
        plt.setp(axs2_ylabel_text, size=7, weight='bold', color='black')
        #设置图例
        didntLike = mlines.Line2D([], [], color='black', marker='.',
                          markersize=6, label='didntLike')
        smallDoses = mlines.Line2D([], [], color='orange', marker='.',
                          markersize=6, label='smallDoses')
        largeDoses = mlines.Line2D([], [], color='red', marker='.',
                          markersize=6, label='largeDoses')
        #添加图例
        axs[0][0].legend(handles=[didntLike,smallDoses,largeDoses])
        axs[0][1].legend(handles=[didntLike,smallDoses,largeDoses])
        axs[1][0].legend(handles=[didntLike,smallDoses,largeDoses])
        #显示图片
        plt.show()
    View Code
      4)数据准备:数据归一化
      使用autoNorm函数自动将数据归一化
    """
    函数说明:对数据进行归一化
     
    Parameters:
        dataSet - 特征矩阵
    Returns:
        normDataSet - 归一化后的特征矩阵
        ranges - 数据范围
        minVals - 数据最小值
     
    """
    def autoNorm(dataSet):
        #获得数据的最小值
        minVals = dataSet.min(0)
        maxVals = dataSet.max(0)
        #最大值和最小值的范围
        ranges = maxVals - minVals
        #shape(dataSet)返回dataSet的矩阵行列数
        normDataSet = np.zeros(np.shape(dataSet))
        #返回dataSet的行数
        m = dataSet.shape[0]
        #原始值减去最小值
        normDataSet = dataSet - np.tile(minVals, (m, 1))
        #除以最大和最小值的差,得到归一化数据
        normDataSet = normDataSet / np.tile(ranges, (m, 1))
        #返回归一化数据结果,数据范围,最小值
        return normDataSet, ranges, minVals
    View Code
      5)构建、验证分类器
      将数据分为90%样本集和10%的测试机(可以调整)
    # -*- coding: UTF-8 -*-
    import numpy as np
    import operator
    
    """
    函数说明:kNN算法,分类器
    
    Parameters:
        inX - 用于分类的数据(测试集)
        dataSet - 用于训练的数据(训练集)
        labes - 分类标签
        k - kNN算法参数,选择距离最小的k个点
    Returns:
        sortedClassCount[0][0] - 分类结果
    
    """
    def classify0(inX, dataSet, labels, k):
        #numpy函数shape[0]返回dataSet的行数
        dataSetSize = dataSet.shape[0]
        #在列向量方向上重复inX共1次(横向),行向量方向上重复inX共dataSetSize次(纵向)
        diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
        #二维特征相减后平方
        sqDiffMat = diffMat**2
        #sum()所有元素相加,sum(0)列相加,sum(1)行相加
        sqDistances = sqDiffMat.sum(axis=1)
        #开方,计算出距离
        distances = sqDistances**0.5
        #返回distances中元素从小到大排序后的索引值
        sortedDistIndices = distances.argsort()
        #定一个记录类别次数的字典
        classCount = {}
        for i in range(k):
            #取出前k个元素的类别
            voteIlabel = labels[sortedDistIndices[i]]
            #dict.get(key,default=None),字典的get()方法,返回指定键的值,如果值不在字典中返回默认值。
            #计算类别次数
            classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
        #python3中用items()替换python2中的iteritems()
        #key=operator.itemgetter(1)根据字典的值进行排序
        #key=operator.itemgetter(0)根据字典的键进行排序
        #reverse降序排序字典
        sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
        #返回次数最多的类别,即所要分类的类别
        return sortedClassCount[0][0]
    
    """
    函数说明:分类器测试函数
    
    Parameters:
        无
    Returns:
        normDataSet - 归一化后的特征矩阵
        ranges - 数据范围
        minVals - 数据最小值
    
    """
    def datingClassTest():
        #打开的文件名
        filename = "datingTestSet.txt"
        #将返回的特征矩阵和分类向量分别存储到datingDataMat和datingLabels中
        datingDataMat, datingLabels = file2matrix(filename)
        #取所有数据的百分之十
        hoRatio = 0.10
        #数据归一化,返回归一化后的矩阵,数据范围,数据最小值
        normMat, ranges, minVals = autoNorm(datingDataMat)
        #获得normMat的行数
        m = normMat.shape[0]
        #百分之十的测试数据的个数
        numTestVecs = int(m * hoRatio)
        #分类错误计数
        errorCount = 0.0
    
        for i in range(numTestVecs):
            #前numTestVecs个数据作为测试集,后m-numTestVecs个数据作为训练集
            classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m,:],
                datingLabels[numTestVecs:m], 4)
            print("分类结果:%d	真实类别:%d" % (classifierResult, datingLabels[i]))
            if classifierResult != datingLabels[i]:
                errorCount += 1.0
        print("错误率:%f%%" %(errorCount/float(numTestVecs)*100))
    View Code
      6)使用算法:构建完整可用系统
    """
    函数说明:通过输入一个人的三维特征,进行分类输出
    
    Parameters:
        无
    Returns:
        无
    """
    def classifyPerson():
        #输出结果
        resultList = ['讨厌','有些喜欢','非常喜欢']
        #三维特征用户输入
        precentTats = float(input("玩视频游戏所耗时间百分比:"))
        ffMiles = float(input("每年获得的飞行常客里程数:"))
        iceCream = float(input("每周消费的冰激淋公升数:"))
        #打开的文件名
        filename = "datingTestSet.txt"
        #打开并处理数据
        datingDataMat, datingLabels = file2matrix(filename)
        #训练集归一化
        normMat, ranges, minVals = autoNorm(datingDataMat)
        #生成NumPy数组,测试集
        inArr = np.array([ffMiles, precentTats, iceCream])
        #测试集归一化
        norminArr = (inArr - minVals) / ranges
        #返回分类结果
        classifierResult = classify0(norminArr, normMat, datingLabels, 3)
        #打印结果
        print("你可能%s这个人" % (resultList[classifierResult-1]))
    View Code
    在cmd中,运行程序,并输入数据(12,44000,0.5),预测结果是"你可能有些喜欢这个人",也就是这个人魅力一般。一共有三个档次:讨厌、有些喜欢、非常喜欢,对应着不喜欢的人、魅力一般的人、极具魅力的人。
    本部分完整代码请见:

    2.2 sklearn包实现

    关于sklearn的详细介绍,请见之前的博客 https://www.cnblogs.com/geo-will/p/9512578.html

    2.2.1 

    sklearn实现k-近邻算法简介 官方文档

    2.2.2 KNeighborsClassifier函数8个参数

      - n_neighbors:k值,选取最近的k个点,默认为5。
      - weights:默认是uniform,参数可以是uniform(均等权重)、distance(按距离分配权重),也可以是用户自己定义的函数。uniform是均等的权重,就说所有的邻近点的权重都是相等的。
      - algorithm:快速k近邻搜索算法,默认参数为auto。除此之外,用户也可以自己指定搜索算法ball_tree、kd_tree、brute方法进行搜索。
      - leaf_size:默认是30,这个是构造的kd树和ball树的大小。这个值的设置会影响树构建的速度和搜索速度,同样也影响着存储树所需的内存大小。需要根据问题的性质选择最优的大小。
      - metric:用于距离度量,默认度量是minkowski,也就是p=2的欧氏距离(欧几里德度量)。
      - p:距离度量公式。欧氏距离和曼哈顿距离。这个参数默认为2,也可以设置为1。
      - metric_params:距离公式的其他关键参数,这个可以不管,使用默认的None即可。
      - n_jobs:并行处理设置。默认为1,临近点搜索并行工作数。如果为-1,那么CPU的所有cores都用于并行工作。

    2.2.3 实例

    基于sklearn实现手写数字识别系统
    # -*- coding: UTF-8 -*-
    import numpy as np
    import operator
    from os import listdir
    from sklearn.neighbors import KNeighborsClassifier as kNN
    
    """
    函数说明:将32x32的二进制图像转换为1x1024向量。
    
    Parameters:
        filename - 文件名
    Returns:
        returnVect - 返回的二进制图像的1x1024向量
    
    """
    def img2vector(filename):
        #创建1x1024零向量
        returnVect = np.zeros((1, 1024))
        #打开文件
        fr = open(filename)
        #按行读取
        for i in range(32):
            #读一行数据
            lineStr = fr.readline()
            #每一行的前32个元素依次添加到returnVect中
            for j in range(32):
                returnVect[0, 32*i+j] = int(lineStr[j])
        #返回转换后的1x1024向量
        return returnVect
    
    """
    函数说明:手写数字分类测试
    
    Parameters:
        无
    Returns:
        无
    
    """
    def handwritingClassTest():
        #测试集的Labels
        hwLabels = []
        #返回trainingDigits目录下的文件名
        trainingFileList = listdir('trainingDigits')
        #返回文件夹下文件的个数
        m = len(trainingFileList)
        #初始化训练的Mat矩阵,测试集
        trainingMat = np.zeros((m, 1024))
        #从文件名中解析出训练集的类别
        for i in range(m):
            #获得文件的名字
            fileNameStr = trainingFileList[i]
            #获得分类的数字
            classNumber = int(fileNameStr.split('_')[0])
            #将获得的类别添加到hwLabels中
            hwLabels.append(classNumber)
            #将每一个文件的1x1024数据存储到trainingMat矩阵中
            trainingMat[i,:] = img2vector('trainingDigits/%s' % (fileNameStr))
        #构建kNN分类器
        neigh = kNN(n_neighbors = 3, algorithm = 'auto')
        #拟合模型, trainingMat为训练矩阵,hwLabels为对应的标签
        neigh.fit(trainingMat, hwLabels)
        #返回testDigits目录下的文件列表
        testFileList = listdir('testDigits')
        #错误检测计数
        errorCount = 0.0
        #测试数据的数量
        mTest = len(testFileList)
        #从文件中解析出测试集的类别并进行分类测试
        for i in range(mTest):
            #获得文件的名字
            fileNameStr = testFileList[i]
            #获得分类的数字
            classNumber = int(fileNameStr.split('_')[0])
            #获得测试集的1x1024向量,用于训练
            vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
            #获得预测结果
            # classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
            classifierResult = neigh.predict(vectorUnderTest)
            print("分类返回结果为%d	真实结果为%d" % (classifierResult, classNumber))
            if(classifierResult != classNumber):
                errorCount += 1.0
        print("总共错了%d个数据
    错误率为%f%%" % (errorCount, errorCount/mTest * 100))
    
    
    """
    函数说明:main函数
    
    Parameters:
        无
    Returns:
        无
    
    """
    if __name__ == '__main__':
        handwritingClassTest()
    View Code
    可以尝试更改这些参数的设置,加深对其函数的理解。
     
     
    参考:
     

  • 相关阅读:
    项目的搭建步骤:
    MySQL复习笔记记录
    记录搭建SSM框架中常用到的功能:(监听器)、过滤器和拦截器 以及相关的拓展内容的学习记录
    为什么要配置spring**.xml或者applicationContext.xml --学习笔记
    java8 Lambda表达式学习笔记
    java多线程高并发学习从零开始——初识volatile关键字
    java多线程高并发学习从零开始——新建线程
    JVM工作机制以及异常处理之内存溢出OOM(OutOfMemoryError)/SOF(StackOverflowError)--Java学习记录2——更新中
    记录使用idea构建出现错误:failed to execute goal org.apache.maven.plugins:maven-javadoc-plugin:2.9.1:jar——Java学习记录3
    Oracle创建用户和表空间的步骤 和 导入dmp文件的方法 —— 数据库学习
  • 原文地址:https://www.cnblogs.com/geo-will/p/9771528.html
Copyright © 2011-2022 走看看