  • KNN Algorithm

    1. Algorithm overview

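    KNN (k-nearest neighbors) is one of the most basic and simplest supervised learning algorithms. Given a sample, it first computes the distance to every training sample to find the topK nearest ones, then decides the sample's label by a vote over the labels of those topK samples.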
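
    The voting step for a single sample can be sketched as follows. This is a minimal illustration, not the code from this post: the function name knn_vote and the toy data are made up, and labels are assumed to be non-negative integers so that np.bincount can count them.

```python
import numpy as np

def knn_vote(sample, train_x, train_y, top_k=3):
    """Label one sample by majority vote among its top_k nearest training samples."""
    # Squared Euclidean distance from the sample to every training row.
    dist = np.sum((train_x - sample) ** 2, axis=1)
    # Indices of the top_k smallest distances.
    nearest = np.argsort(dist)[:top_k]
    # Count each label among the neighbours; on a tie argmax picks the smaller label.
    return int(np.argmax(np.bincount(train_y[nearest])))

# Toy data (hypothetical): two samples of class 0 and three of class 1.
train_x = np.array([[0.0, 0.0], [0.1, 0.2], [1.0, 1.0], [0.9, 1.1], [1.2, 0.8]])
train_y = np.array([0, 0, 1, 1, 1])

label = knn_vote(np.array([1.0, 0.9]), train_x, train_y, top_k=3)
print(label)
```

    Here the three nearest neighbours of the query point all belong to class 1, so the vote is unanimous.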

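    Training: simply load the training data.

    Testing: using the loaded training data, compute a label for each sample in the test set, which labels the whole test set.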

    2. Code

    The full code is as follows:

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    
    import csv
    import random
    from matplotlib import pyplot as plt
    import numpy as np
    from sklearn.decomposition import PCA
    
    class KNN(object):
        def __init__(self):
            self._trainData = None
            self._trainDataLabel = None
    
        # Compute the squared Euclidean distance between every test sample and every training sample
        def _computerDist(self,testData):
            m = testData.shape[0]
            n = self._trainData.shape[0]
            dist = np.zeros((m,n))
            for i in range(m):
                for j in range(n):
                    dist[i][j] = np.sum( (testData[i,:] - self._trainData[j,:])**2 )
            return dist
    
        # Train the model; KNN only needs to store the training set
        def train(self,dataset):
            self._trainData = dataset[:,0:-1]
            # The last column holds the integer class label
            self._trainDataLabel = np.array(dataset[:,-1],dtype = int)
    
        # Predict labels for the test set
        def predict(self,testData,topK = 3):
            dist = self._computerDist(testData)
            num_test = testData.shape[0]
            predLabel = np.zeros(num_test,dtype = int)
    
            for i in range(num_test):
                # Indices of the topK nearest training samples
                idxList = np.argsort(dist[i,:])[:topK].tolist()
                # Labels of those samples
                labelList = self._trainDataLabel[idxList]
                # Count how many times each label occurs
                counts = np.bincount(labelList)
                # The most frequent label becomes the prediction (ties go to the smaller label)
                predLabel[i] = np.argmax(counts)
            return predLabel
    
        # Measure accuracy on the test set
        def test(self,testData,testLabel,topK = 3):
            predLabel = self.predict(testData,topK)
            predLabel = np.array(predLabel,dtype = int)
            num_correct = np.sum(predLabel == testLabel)
            num_test = testLabel.shape[0]
            accuracy = float(num_correct) / num_test
            print("testLabel:" + str(testLabel))
            print("predLabel:" + str(predLabel))
            print("get: %d / % d => accuracy: %f" %(num_correct,num_test,accuracy))
            return predLabel
    
        # Plot the results
        def plotResult(self,testData,predLabel):
            X = self._trainData
            y = self._trainDataLabel
    
            # Fit PCA on the training data only, then project both sets with the same components
            pca = PCA(n_components=2)
            X_r = pca.fit(X).transform(X)
            test_r = pca.transform(testData)
    
            plt.figure()
            for c, i in zip("rgb", [0, 1, 2]):
                plt.scatter(X_r[y == i, 0], X_r[y == i, 1], c=c, label="train, class %d" % i)
                plt.scatter(test_r[predLabel == i,0],test_r[predLabel == i,1],s= 30,c = c,marker = 'D', label="test, class %d" % i)
            plt.legend()
            plt.title('KNN of IRIS dataset')
            plt.show()
    
        # Load the dataset (every column, including the label in the last one, must be numeric)
        def loadDataSet(self,fileName,splitRatio = 0.9):
            with open(fileName,"r") as f:
                # Skip empty rows and convert every field to float
                dataset = [[float(x) for x in row] for row in csv.reader(f) if row]
    
            # Shuffle, then split into training and test parts
            trainSize = int(len(dataset) * splitRatio)
            random.shuffle(dataset)
            trainData = np.array(dataset[:trainSize])
            testData = np.array(dataset[trainSize:])
            return trainData,testData
    
    if __name__ == "__main__":
        fileName = 'iris.csv'
        KNNobj = KNN()
        trainData,testData = KNNobj.loadDataSet(fileName,0.8)
        # Extract the test features
        testdata = testData[:,0:-1]
        # Extract the test labels
        testdataLabel = np.array(testData[:,-1],dtype = int)
        # Train the model
        KNNobj.train(trainData)
        # Test the model
        predLabel = KNNobj.test(testdata,testdataLabel,3)
        # Plot the result distribution
        KNNobj.plotResult(testdata,predLabel)
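
    The double loop in _computerDist is easy to read but slow on larger datasets. As a hedged alternative, the same squared Euclidean distances can be computed with NumPy broadcasting. The function names and random data below are illustrative only and are not part of the class above.

```python
import numpy as np

def pairwise_sq_dist_loop(test_x, train_x):
    # Same computation as the nested loop in the class: one pair at a time.
    dist = np.zeros((test_x.shape[0], train_x.shape[0]))
    for i in range(test_x.shape[0]):
        for j in range(train_x.shape[0]):
            dist[i, j] = np.sum((test_x[i] - train_x[j]) ** 2)
    return dist

def pairwise_sq_dist_vectorized(test_x, train_x):
    # (m, 1, d) - (1, n, d) broadcasts to (m, n, d); summing over d gives (m, n).
    diff = test_x[:, np.newaxis, :] - train_x[np.newaxis, :, :]
    return np.sum(diff ** 2, axis=2)

# Random stand-in data with 4 features per sample, like Iris.
rng = np.random.default_rng(0)
test_x = rng.normal(size=(5, 4))
train_x = rng.normal(size=(7, 4))

loop_dist = pairwise_sq_dist_loop(test_x, train_x)
vec_dist = pairwise_sq_dist_vectorized(test_x, train_x)
same = bool(np.allclose(loop_dist, vec_dist))
print(same)
```

    The broadcast version builds an intermediate array of shape (m, n, d), so it trades memory for speed; for a dataset the size of Iris either approach is fine.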
    

    3. Result analysis

    In this example, the training set has 120 samples, the test set has 30 samples, and topK = 3.

    The output is as follows:

    get: 29 /  30 => accuracy: 0.966667
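
    The numbers above follow directly from the split ratio and the accuracy formula used in test(). A small check, assuming the dataset has 150 rows in total (implied by 120 + 30):

```python
num_samples = 150   # assumption: 120 training + 30 test samples
split_ratio = 0.8   # the ratio passed to loadDataSet in the main block

train_size = int(num_samples * split_ratio)
test_size = num_samples - train_size

num_correct = 29
accuracy = float(num_correct) / test_size

# Same format string as in test(); "% d" pads positive numbers with a space.
line = "get: %d / % d => accuracy: %f" % (num_correct, test_size, accuracy)
print(line)
```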
    

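    The resulting distribution plot is shown below:

    Round markers are training data and diamond markers are test data; different colors denote different classes.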
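
    For both sets of markers to share one coordinate system, the test data has to be projected with the components learned from the training data. The sketch below illustrates this idea with a plain NumPy SVD instead of scikit-learn's PCA. The helper names fit_pca_2d and project_2d and the random data are hypothetical.

```python
import numpy as np

def fit_pca_2d(train_x):
    """Return the training mean and the first two principal directions."""
    mean = train_x.mean(axis=0)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(train_x - mean, full_matrices=False)
    return mean, vt[:2]

def project_2d(x, mean, components):
    # Centre with the training mean, then project onto the training components.
    return (x - mean) @ components.T

# Random stand-in data with the same shapes as the 120/30 split.
rng = np.random.default_rng(0)
train_x = rng.normal(size=(120, 4))
test_x = rng.normal(size=(30, 4))

mean, components = fit_pca_2d(train_x)
train_2d = project_2d(train_x, mean, components)
test_2d = project_2d(test_x, mean, components)
print(train_2d.shape, test_2d.shape)
```

    Fitting a second projection on the test data alone would give axes that need not match those of the training projection, so the two point clouds could not be compared in one plot.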

    4. References

    Comparison of LDA and PCA 2D projection of Iris dataset

  • Original article: https://www.cnblogs.com/zhbzz2007/p/5528091.html