zoukankan      html  css  js  c++  java
  • 机器学习笔记(一)之k近邻算法(k-Nearest Neighbor)

    k近邻算法(kNN)是监督学习的一种。其原理非常简单:存在一个样本数据集,也称作训练样本集。样本集中的每个数据都存在标签,即知道数据与对应分类的关系。输入新的没有标签的数据,将新的数据的每个特征与样本中的数据特征进行对比,然后利用算法提取出样本集中特征最相似的数据(最邻近)分类标签。一般来说我们只选取样本集中前k个最相似的数据。

    k近邻算法一般流程:

    1.选择一种距离计算方式,通过数据所有的特征计算新的数据与已知数据集中的距离

    2.按照距离递增的顺序进行排序,选取与当前距离最近的k个点

    3.返回k个点出现频率最多的类作为预测类。如果是回归问题,则需要加上权值。

    需要注意的是kNN是不需要训练的。

    kNN算法的关键

    1.k的选择

    举一个简单的例子,绿色圆表示要被赋值的点,是红色三角形还是蓝色四边形?如果k为3,则红色三角形出现的比例是2/3,所以预测结果为红色三角形。若k=5,则蓝色四边形出现的比例为3/5,所以预测结果为蓝色四边形。

                                                                         

    因此k的选择对于结果的影响非常大。如果k值太小,受到噪声干扰比较明显,容易受到出现过拟合现象。而k值过大则会导致其分界不明显。一种选择K值得方法是使用 cross-validate(交叉验证)误差统计选择法。也就是将数据样本的一部分作为训练样本,一部分作为测试样本。一般来说选择90%的作为训练数据集,剩下的作为测试集。选择不同的k值计算误差,最后选出误差最小的k值。

    2.需要对数据所有特征做可比较量化

    如果数据中存在非数值类型,必须对其进行数值化。例如样本中包含颜色(红绿蓝)。颜色本身时没有距离的,但是我们可以将其转换成灰度值在进行计算。另外,样本有多个参数(特征),每个参数都有自己的定义域和取值范围。它们对于距离计算的影响也就不一样。为了公平起见,我们必须将特征进行归一化。


    下面是python的实现代码

    def createDataSet():
        group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
        labels = ['A','A','B','B']
        return group, labels

    创建数据和对应的标签

    def classify0(inX, dataSet, labels, k):
        dataSetSize=dataSet.shape[0]#返回dataset的第一维的长度
        diffMat = tile(inX, (dataSetSize,1)) - dataSet#计算各个点到原点的x轴,y轴的距离。
        #计算出各点离原点的距离
        #表示diffMat的平方
        sqDiffMat = diffMat**2#平方只针对数组有效
        sqDistances=sqDiffMat.sum(axis = 1)
        distances=sqDistances**0.5
        sortedDistIndices = distances.argsort()#返回从小到大的引索
        classCount = {}
        for i in range(k):
            voteLabel = labels[sortedDistIndices[i]]#找到对应的从小到大的标签
            classCount[voteLabel] = classCount.get(voteLabel,0)+1
            print(classCount.get(voteLabel,0)+1)
            print(classCount)
            sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
        return sortedClassCount[0][0]

    调用函数

    group,labels = createDataSet()
        classer=classify0([0,0],group,labels,3)
    group表示的是一个4*2的数组,也就是说dataSet是一个4*2的矩阵。dataSet.shape[0]返回的值为4。tile表示将数组进行复制。然后计算出各个点距离原点的欧氏距离。最后开根号排序,选择距离最近的k各点。

    第二步:将数据从文本中导入,并将文本记录转换成Numpy的解析程序。

    def file2matrix(filename):
        fr=open(filename)#打开文件
        arrayOLines=fr.readlines()#读取所有行的数据,直到遇到结束符
        numberOfLines=len(arrayOLines)
        returnMat=zeros((numberOfLines,3))
        classLabelVector=[]
        index = 0
        for lines in arrayOLines:
            lines = lines.strip()#截取掉后面的换行符
            listFromLine = lines.split('	')#
            returnMat[index,:]=listFromLine[0:3]
            classLabelVector.append(int(listFromLine[-1]))
            index += 1
        return returnMat,classLabelVector
    第三部:分析数据,将数据用散点图表示出来

    def show(datingDataMat,datingLabels):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],15.0*array(datingLabels),15.0*array(datingLabels))
        plt.show() 


    最后利用kNN实现手写识别程序的完整代码

    from numpy import *
    import operator
    import matplotlib
    import matplotlib.pyplot as plt
    from os import listdir
    
    def classify0(inX, dataSet, labels, k):
        dataSetSize=dataSet.shape[0]#返回dataset的第一维的长度
        print(dataSetSize)
        diffMat = tile(inX, (dataSetSize,1)) - dataSet
        #计算出各点离原点的距离
        #表示diffMat的平方
        sqDiffMat = diffMat**2#平方只针对数组有效
        sqDistances=sqDiffMat.sum(axis = 1)
        distances=sqDistances**0.5
        sortedDistIndices = distances.argsort()#返回从小到大的引索
        classCount = {}
        for i in range(k):
            voteLabel = labels[sortedDistIndices[i]]#找到对应的从小到大的标签
            classCount[voteLabel] = classCount.get(voteLabel,0)+1
            sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
        return sortedClassCount[0][0]
    
    def createDataSet():
        group=array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])#numpy里面的数组,注意和list的区别
        labels=['A','A','B','B']
        return group,labels
    
    def file2matrix(filename):
        fr=open(filename)
        arrayOLines=fr.readlines()
        numberOfLines=len(arrayOLines)
        print(numberOfLines)
        returnMat=zeros((numberOfLines,3))
        classLabelVector=[]
        index = 0
        for lines in arrayOLines:
            lines = lines.strip()
            listFromLine = lines.split('	')
            returnMat[index,:]=listFromLine[0:3]
            classLabelVector.append(int(listFromLine[-1]))
            index += 1
        return returnMat,classLabelVector
    
    def show(datingDataMat,datingLabels):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],15.0*array(datingLabels),15.0*array(datingLabels))
        plt.show()
    
    
    def autoNorm(dataSet):#将特征值归一化
        minVals=dataSet.min(0)#选择数据集中最小的
        maxVals=dataSet.max(0)
        ranges = maxVals - minVals
        normDataSet=zeros(shape(dataSet))
        m = dataSet.shape[0]
        normDataSet = dataSet-tile(minVals,(m,1))
        normDataSet = normDataSet/tile(ranges,(m,1))
        return normDataSet,ranges,minVals
    
    def datingClassTest():
        hoRatio = 0.50  # hold out 10%
        datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')  # load data setfrom file
        normMat, ranges, minVals = autoNorm(datingDataMat)
        m = normMat.shape[0]
        numTestVecs = int(m * hoRatio)
        errorCount = 0.0
        for i in range(numTestVecs):
            classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
            print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
            if (classifierResult != datingLabels[i]):
                errorCount += 1.0
                print( "the total error rate is: %f" % (errorCount / float(numTestVecs)))
               # print(errorCount)
    
    def img2vector(filename):
        returnVect = zeros((1, 1024))
        fr = open(filename)
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect
    
    def handwritingClassTest():
        hwLabels = []
        trainingFileList = listdir('trainingDigits')  # load the training set
        m = len(trainingFileList)
        trainingMat = zeros((m, 1024))
        for i in range(m):
            fileNameStr = trainingFileList[i]
            fileStr = fileNameStr.split('.')[0]  # take off .txt
            classNumStr = int(fileStr.split('_')[0])
            hwLabels.append(classNumStr)
            trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
        testFileList = listdir('testDigits')  # iterate through the test set
        errorCount = 0.0
        mTest = len(testFileList)
        for i in range(mTest):
            fileNameStr = testFileList[i]
            fileStr = fileNameStr.split('.')[0]  # take off .txt
            classNumStr = int(fileStr.split('_')[0])
            vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
            classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
            print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
            if (classifierResult != classNumStr): errorCount += 1.0
        print("
    the total number of errors is: %d" % errorCount)
        print("
    the total error rate is: %f" % (errorCount / float(mTest)))
    
    
    
    if __name__ == "__main__":
        group,labels = createDataSet()
        classer=classify0([0,0],group,labels,3)
      #  handwritingClassTest()
        datingDataMat, datingLabels=file2matrix('datingTestSet2.txt')
        show(datingDataMat,datingLabels)
    
    
    
    
    
    




  • 相关阅读:
    error occurred during initialization of vm
    Service Discovery protocol(SDP)
    nRF51822EK_PRO
    Binder
    android samsung note3  device not found
    BLE pairing vs. bonding
    remote link Centos6.6 Horrible Slow
    depmod -a
    start and end call use itelephony and how to pick up a call
    WebService
  • 原文地址:https://www.cnblogs.com/gaot/p/7709691.html
Copyright © 2011-2022 走看看