zoukankan      html  css  js  c++  java
  • KNN分类算法

    K最近邻(KNN,K-NearestNeighbor)是1967年由Cover T和Hart P提出的一种基本分类与回归方法,它是数据挖掘分类技术中最简单的方法之一,非常容易理解应用。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的(一般用距离最短表示最接近)K个邻居来代表。如果K个邻居里大多数都属于某一个类别,那么该样本也被划分为这个类别。

    KNN算法中所选择的邻居都是已经正确分类的对象,属于懒惰学习,即KNN没有显式的学习过程,没有训练阶段。待收到新样本后直接进行处理。
     
    算法描述:
    1)计算测试数据与各个训练数据之间的距离;
    2)按照距离的递增关系进行排序;
    3)选取距离最小的K个点;
    4)确定前K个点所在类别的出现频率;
    5)返回前K个点中出现频率最高的类别作为测试数据的预测分类.
     
    通常k是不大于20的整数,上限是训练数据集数量n的开方,随着数据集的增大,K的值也要增大。
    依赖于训练数据集和K的取值,输出结果可能会有不同。所以需要评估算法的正确率,选取产生最小误差率的K:比如我们可以提供已有数据的90%作为训练样本来训练分类器,而使用其余的10%数据去测试分类器,检测错误率是否随着K值的变化而减小。需要注意的是,10%的测试数据应该是随机选择的。
     
     
    python示例1(sklearn包封装了KNN算法)
    import numpy as np
    from sklearn import neighbors        
    from sklearn import datasets
     
    knn = neighbors.KNeighborsClassifier()
    iris = datasets.load_iris() #加载"Anaconda2Libsite-packagessklearndatasetsdatairis.csv"
    #print iris
     
    #print iris.data
    #print iris.target
    knn.fit(iris.data, iris.target) #用KNN分类器进行建模,这里用的默认的参数
    predictedLabel = knn.predict([[5.1, 5.2, 5.3, 5.4]])
    print ("predictedLabel is :" + str(predictedLabel))
    predictedLabel = knn.predict([[0.1, 0.2, 0.3, 0.4]])
    print ("predictedLabel is :" + str(predictedLabel))
     
    #拆分测试集和训练集,计算算法的准确率
    np.random.seed(0)
    #permutation随机生成50个训练数,测试集就是sum-50
    indices = np.random.permutation(len(iris_y))
    iris_X_train = iris_X[indices[:50]]
    iris_y_train = iris_y[indices[:50]]
    iris_X_test = iris_X[indices[50:]]
    iris_y_test = iris_y[indices[50:]]
    #print iris_X_train 
    #用KNN分类器进行建模,和上面相比,不再使用所有数据,而是训练集数据
    knn.fit(iris_X_train, iris_y_train)
    #预测所有测试集数据的鸢尾花类型
    predict_result = knn.predict(iris_X_test)
    #计算预测的准确率
    print(knn.score(iris_X_test, iris_y_test))
     
     
    python示例2(自己写KNN算法,简单示例)
    import numpy as np
    import math
    a=[]
    p=np.array([0,4,5])
    p1=np.array([3,0,0])
    p2=np.array([3,1,2])
    p3=np.array([3,4,5])
    labels=['red','blue','yellow']#p1/p2/p3分别对应的label
    p11=p1-p
    p21=p2-p
    p31=p3-p
     
    #math.hypot用来计算两点之间距离((x1-x2)^2+(y1-y2)^2)^0.5
    #因为是三维数组,所以需要计算三点之间距离,math.hypot只能传参2个,所以计算2次达到目的
    p111=math.hypot(p11[0],p11[1])
    p1111=math.hypot(p11[2],p111)
    a.append(p1111)
    p211=math.hypot(p21[0],p21[1])
    p2111=math.hypot(p21[2],p211)
    a.append(p2111)
    p311=math.hypot(p31[0],p31[1])
    p3111=math.hypot(p31[2],p311)
    a.append(p3111)
     
    #对数组a排序,返回索引,第一个索引对应的训练数据和p的距离最短,那么p的label=该数据的label
    #相当于K=1,只找最近的那一个邻居
    a=np.array(a)
    b=a.argsort()
    print labels[b[0]]
     
     
    python示例3(自己写KNN算法,进阶示例)
    import numpy as np
    from math import sqrt
    import operator as opt
     
    #对数据进行[0,1]格式化处理
    def normData(dataSet):
       maxVals = dataSet.max(axis=0)
       minVals = dataSet.min(axis=0)
       ranges = maxVals - minVals
       retData = (dataSet - minVals) / ranges
       return retData, ranges, minVals
     
    def kNN(dataSet, labels, testData, k):
       distSquareMat = (dataSet - testData) ** 2 # 计算差值的平方
       distSquareSums = distSquareMat.sum(axis=1) # 求每一行的差值平方和
       distances = distSquareSums ** 0.5 # 开根号,得出每个样本到测试点的距离
       sortedIndices = distances.argsort() # 排序,得到排序后的下标
       indices = sortedIndices[:k] # 取最小的k个
       labelCount = {} # 存储每个label的出现次数
       for i in indices:
           label = labels[i]
          labelCount[label] = labelCount.get(label, 0) + 1 # 次数加一
       sortedCount = sorted(labelCount.items(), key=opt.itemgetter(1), reverse=True) 
       # 对label出现的次数从大到小进行排序
       return sortedCount[0][0] # 返回出现次数最大的label
     
    if __name__ == "__main__":
       dataSet = np.array([[2.0, 3], [6.0, 8], [3.0, 4]])
       normDataSet, ranges, minVals = normData(dataSet)   
       print normDataSet
       labels = ['a', 'a', 'c']
       testData = np.array([3.0, 5])
       normTestData = (testData - minVals) / ranges
       print normTestData
       result = kNN(normDataSet, labels, normTestData, 3)
       #k=1时就会取到最近的那一个点[3.0, 4]对应的label'c'
       print(result)
     
     
     
    学习参考:

    https://www.jianshu.com/p/33dbf9906ff2 介绍全面

    http://cda.pinggu.org/view/26396.html 图片数字识别的例子值得深思

     
     
     
  • 相关阅读:
    信息安全系统设计基础第二周学习总结(20135213)
    《深入理解计算机系统》第一节课课堂笔记(20135213)
    (20135213)信息安全系统设计基础第一周学习总结(共12课)课程(6~12)
    (20135213)信息安全系统设计基础第一周学习总结(共12课)课程(1~5)
    实验五 — — Java网络编程及安全
    20135220谈愈敏--信息安全系统设计基础期中总结
    20135220谈愈敏-第三章家庭作业
    20135220谈愈敏--信息安全系统设计基础第六周学习总结
    20135220谈愈敏-第二章家庭作业
    20135220谈愈敏--信息安全系统设计基础第五周学习总结
  • 原文地址:https://www.cnblogs.com/myshuzhimei/p/11724176.html
Copyright © 2011-2022 走看看