zoukankan      html  css  js  c++  java
  • KNN python实现

    参考网址:https://my.oschina.net/fengcunhan/blog/101281

    http://www.cnblogs.com/geniferology/p/what_is_kNN_algorithm.html

    kNN 的算法就是:
    
    在已知的 data points 中,逐一点检视(把這每一點叫作 P):
    1、首先计算「?」和 P 之间的距离
    2、所有距离计算之后,将他们由小至大 sort 好
    3、从 sort 好的序列,取最前的 k 个(即距离最接近「?」的 k 个点子)
    4、对这 k 个点,读出他们的 label(颜色)是什么,这是问题中已经知道的
    5、所有这些 labels(颜色),哪个出现最多?  (亦即是说,最接近「?」的 k 个点子,它们最普遍是什么颜色?)
    这出现次数最多的颜色,就是答案

    距离算法也有很多种,可以选择一种即可。。

    from numpy import *
    import operator
    
    def classify(inMat,dataSet,labels,k):
        dataSetSize=dataSet.shape[0]
                                       #KNN的算法核心就是欧式距离的计算,一下三行是计算待分类的点和训练集中的任一点的欧式距离
        diffMat=tile(inMat,(dataSetSize,1))-dataSet
        sqDiffMat=diffMat**2
        distance=sqDiffMat.sum(axis=1)**0.5
                                               #接下来是一些统计工作
        sortedDistIndicies=distance.argsort()  #将x中的元素从小到大排列,提取其对应的index(索引)
        classCount={}
        for i in range(k):                     #k =3 训练数据最近的K个点看看这几个点属于什么类型
            labelName=labels[sortedDistIndicies[i]]
            classCount[labelName]=classCount.get(labelName,0)+1;      #get(labelName,0)如果存在键labelName,返回它的value,否则返回0 
        sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
        return sortedClassCount[0][0]
    
    def file2Mat(testFileName,parammterNumber):
        fr=open(testFileName)
        lines=fr.readlines()
        lineNums=len(lines)
        resultMat=zeros((lineNums,parammterNumber))
        classLabelVector=[]
        for i in range(lineNums):
            line=lines[i].strip()
            itemMat=line.split('	')
            resultMat[i,:]=itemMat[0:parammterNumber]
            classLabelVector.append(itemMat[-1])
        fr.close()
        return resultMat,classLabelVector;
        
    #为了防止某个属性对结果产生很大的影响,所以有了这个优化,比如:10000,4.5,6.8 10000就对结果基本起了决定作用
    def autoNorm(dataSet):                         #dataset[n*3]
        minVals=dataSet.min(0)                     #数组每列的最小值 minVals[1*3]
        maxVals=dataSet.max(0)                     #数组每列的最小值
        ranges=maxVals-minVals
        normMat=zeros(shape(dataSet))              #全零数组(与原数据数组大小一致)
        size=normMat.shape[0]                      #列数
        normMat=dataSet-tile(minVals,(size,1))     #tile(minVals,(size,1))是等同于dataset大小的n*3数组
        normMat=normMat/tile(ranges,(size,1))      #差距除以范围:比率?
        return normMat,minVals,ranges
    def test(trainigSetFileName,testFileName):
        trianingMat,classLabel=file2Mat(trainigSetFileName,3)
        trianingMat,minVals,ranges=autoNorm(trianingMat)  #正则矩阵每个数据减去最小值除以数据变化范围宽度
        testMat,testLabel=file2Mat(testFileName,3)
        testSize=testMat.shape[0]                  #取第一列的维度
        errorCount=0.0
        for i in range(testSize):                  #测试集一个一个数据测试
            result=classify((testMat[i]-minVals)/ranges,trianingMat,classLabel,3)
            if(result!=testLabel[i]):
                errorCount+=1.0
        errorRate=errorCount/(float)(len(testLabel))
        return errorRate;
    if __name__=="__main__":
        errorRate=test('train.txt','test.txt')
        print("the error rate is :%f"%(errorRate))

    训练集&测试集特点(最后一列是标签):

    40920	8.326976	0.953952	3
    14488	7.153469	1.673904	2
    26052	1.441871	0.805124	1
    75136	13.147394	0.428964	1
    38344	1.669788	0.134296	1
    72993	10.141740	1.032955	1
    35948	6.830792	1.213192	3
    42666	13.276369	0.543880	3
    67497	8.631577	0.749278	1
    35483	12.273169	1.508053	3

      

  • 相关阅读:
    VirtualBox 命令行操作
    [大数据入门]实战练习 安装Cloudera-Hadoop集群
    DB2 Package Issues and Solution
    关于《阿里巴巴Java开发规约》插件的安装与使用
    Spring学习笔记:Spring概述,第一个IoC依赖注入案例
    SpringMVC:系统认识一下maven
    利用JS提交表单的几种方法和验证(必看篇)
    a标签调用js的几种方法
    使用BigDecimal进行精确运算
    Maven中settings.xml的配置项说明
  • 原文地址:https://www.cnblogs.com/hozhangel/p/8073557.html
Copyright © 2011-2022 走看看