zoukankan      html  css  js  c++  java
  • 机器学习实战python3 K近邻(KNN)算法实现

       台大机器技法跟基石都看完了,但是没有编程一直,现在打算结合周志华的《机器学习》,撸一遍机器学习实战, 原书是python2 的,但是本人感觉python3更好用一些,所以打算用python3 写一遍。python3 与python2 不同的地方会在程序中标出。

    代码及数据:https://github.com/zle1992/MachineLearningInAction

    k-近邻算法

    优点:精度高、对异常值不敏感、无数据输入假定。

    缺点:计算复杂度高、空间复杂度高。对K的取值敏感!!!

    适用数据范围:数值型和标称型。

    1 创建数据

    1 def createDataSet():
    2     group = np.array([
    3     [1.0,1.1],
    4     [1.0,1.0],
    5     [0,0],
    6     [0,0.1] ])
    7     labels = ['A','A','B','B']
    8     return group ,labels

    2 算法

    对未知类别属性的数据集中的每个点依次执行以下操作:
    (1)计算已知类别数据集中的点与当前点之间的距离;
    (2)按照距离递增次序排序;
    (3)选取与当前点距离最小的k个点;
    (4)确定前k个点所在类别的出现频率;
    (5)返回前k个点出现频率最高的类别作为当前点的预测分类。

     1 def classify0(intX,dataX,labels,k):
     2     dataSize = dataX.shape[0] #行数,
     3     diffMat = np.tile(intX,(dataSize,1)) - dataX #intX 复制4行,形成矩阵,并计算距离差
     4     sqDiffMat = diffMat * diffMat  # 等价于 diffMat ** 2
     5     sqDistence = sqDiffMat.sum(axis = 1) #按行相加
     6     distence = sqDistence ** 0.5 # 开根号 , distence是array
     7     sortedDistenceIndicies = distence.argsort() # 返回排序后的下标
     8     classCount = {} # 字典 key is label , val is count  
     9     for i in range(k):
    10         voteIlabel = labels[sortedDistenceIndicies[i]] # 排名第i的label
    11         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #   .get(voteIlabel,0)   if no exit return 0
    12     sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True) 
    13     # 根据iteritems 的第1个元素排序 即字典中的val ,第0个 是key  #reverse定义为True时将按降序排列
    14     #!!!!! 与py2 的差别!! python3 字典的item 就是迭代对象
    15     return sortedClassCount[0][0]  # 返回排序后的字典的前一个(从0开始)

    3在约会网站上使用k近邻算法

    (1)收集数据:提供文本文件。
    (2)准备数据:使用Python解析文本文件。
    (3)分析数据:使用Matplotlib画二维扩散图。
    (4)训练算法:此步骤不适用于k-近邻算法。
    (5)测试算法:使用海伦提供的部分数据作为测试样本。测试样本和非测试样本的区别在于:测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
    (6)使用算法:产生简单的命令行程序,然后海伦可以输入一些特征数据以判断对方是否为自己喜欢的类型。

    约会数据存放在文本文件datingTestSet.txt中,每个样本数据占据一行,总共有1000行。主要包含以下3种特征:

    口每年获得的飞行常客里程数
    口玩视频游戏所耗时间百分比
    口每周消费的冰淇淋公升数

    3-1将文本记录到转换NumPy的解析程序

     1 def file2matrix(filename):
     2     with open(filename,mode = "r") as fr:
     3         arrayOLines = fr.readlines()
     4         numberOfLine = len(arrayOLines)
     5         returnMat = np.zeros((numberOfLine,3))
     6         labels = []
     7         index = 0
     8         for line in arrayOLines:
     9             listFromLine = line.split("	")
    10             returnMat[index,:] = listFromLine[0:3]
    11             labels.append(int(listFromLine[-1]))  #-1 倒数第一个
    12             index = index + 1
    13         return returnMat,labels

    3-2分析数据:使用Matplotlib创建散点图

     1     k = 3
     2     #读入数据
     3     filename = "datingTestSet2.txt"
     4     dataX,labels = file2matrix(filename)
     5     fig = plt.figure()
     6     ax = fig.add_subplot(111)
     7     ax.scatter(dataX[:,0],dataX[:,1],c = 15 *np.array(labels),s = 15*np.array(labels)) # c 是颜色序列!!!! s 是大小
     8     ax = fig.add_subplot(121)
     9     ax.scatter(dataX[:,0],dataX[:,2],c = 15 *np.array(labels),s = 15*np.array(labels)) # c 是颜色序列!!!! s 是大小
    10     ax = fig.add_subplot(131)
    11     ax.scatter(dataX[:,1],dataX[:,2],c = 15 *np.array(labels),s = 15*np.array(labels)) # c 是颜色序列!!!! s 是大小

    3-3准备数据:归一化数值

     1 def autoNorm(dataX):
     2     #  归一化公式 newVal = (oldval-min)/(max - min)
     3     minVals = dataX.min(axis=0)#返回的是最每一列中的最小的元素 min(1)#返回的是最每一行中的最小的元素
     4     maxVals = dataX.max(0)#返回的是最每一列中的最大的元素
     5     ranges = maxVals - minVals #
     6     rows = dataX.shape[0]
     7     newVal = dataX - tile(minVals,(rows,1)) #(oldval-min)
     8     ##   tile(minVals,(rows,1))复制rows行 ,列数为minvals列数的一倍
     9     newVal = newVal/tile(ranges,(rows,1)) #(oldval-min)/(max - min)
    10     return newVal,ranges,minVals

    3-4测试算法:作为完整程序验证分类器

     1 def datingClassTest():
     2     hoRatio = 0.1 # 10% of data as  test
     3     #读入数据
     4     filename = "datingTestSet2.txt"
     5     dataX ,labels = file2matrix(filename)
     6     #归一化
     7     normMat,ranges,minVals = autoNorm(dataX)
     8 
     9     m = dataX.shape[0] #numbers of rows
    10     numTestVecs = int(m * hoRatio) #numbers of test 
    11     errorcount = 0 #initialize number of errors 
    12     for i in range(numTestVecs):
    13         classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],labels[numTestVecs:m],5)#前10%作为测试数据
    14     #   print("the classifier predict %d, the real answer is :%d" %((classifierResult),labels[i]))
    15         if(classifierResult != labels[i]):
    16             errorcount = errorcount + 1.0
    17     print("error rate :%f" %((errorcount)/(numTestVecs)))

    3-5使用算法:构建完整可用系统

     此程序需要在命令行下进行,因为要从命名行读入数据

     1 def classifyPerson():
     2     resultList = ["第一类","第二类","第三类"] #output lables
     3     
     4     percentTats = float(input("玩游戏消耗的时间"))###########!!!!!!! pyhton3 为input#####!!!!!!!!!!!!!!!!!
     5     ffilm = float(input("每年获得的飞行里程数"))###########!!!!!!! pyhton2 为raw_input#####!!!!!!!!!!!!!!!!!
     6     iceCream = float(input("每周消费冰淇淋"))
     7         #读入数据
     8     filename = "datingTestSet2.txt"
     9     dataX ,labels = file2matrix(filename)
    10     #归一化
    11     normMat,ranges,minVals = autoNorm(dataX)
    12 
    13     test_list = np.array([percentTats,ffilm,iceCream])
    14     classifierResult = classify0(test_list,dataX,labels,3)
    15     print("你喜欢的类别:" + resultList[classifierResult])

     4手写识别系统

    (1)收集数据:提供文本文件。

    (2)准备数据:编写函数classify0(),将图像格式转换为分类器使用的list格式。

    (3)分析数据:在Python命令提示符中检查数据,确保它符合要求。

    (4)训练算法:此步骤不适用于裕近邻算法。

    (5)测试算法:编写函数使用提供的部分数据集作为测试样本,测试样本与非测试样本的区别在于测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。

    (6)使用算法:本例没有完成此步骤,若你感兴趣可以构建完整的应用程序,从图像中提取数字,并完成数字识别,美国的邮件分拣系统就是一个实际运行的类似系统。

    4-1准备数据:将图像转换为测试向量

    1 def img2vector(filename):
    2     returnVect = np.zeros((1,1024)) #initilize the vec
    3     with open (filename,mode = "r") as fr:   #########!!!!!!!!与python2 不同   !!!!!!!!
    4         lineStr = fr.readlines()             #########!!!!!!!!与python2 不同   !!!!!!!!
    5         for i in range(32):                  #########!!!!!!!!与python2 不同   !!!!!!!!
    6             for j in range(32):                 #########!!!!!!!!与python2 不同   !!!!!!!!
    7                 returnVect[0,i*32+j] = lineStr[i][j]        #########!!!!!!!!与python2 不同   !!!!!!!!
    8     return returnVect

    4-2测试算法:使用k-近邻算法识别手写数字

     1 def handwritingClassTest():
     2     hwlables = []
     3     trainingFileList = os.listdir("digits\trainingDigits")######!!!!!!!!与python2 不同   !py3 need import os 
     4     m = len(trainingFileList) # number of trianfiles
     5     trainMat = np.zeros((m,1024))
     6     for i in range(m):
     7         fileNameStr = trainingFileList[i]   #循环操作 把每一个txt文件转换为矩阵
     8         fileStr = fileNameStr.split(".")[0]  # 把文件名与后缀名分开,提取文件名
     9         classNumStr = fileStr.split("_")[0]  #提取文件名中的标签
    10         hwlables.append(classNumStr)  #把类别标签添加到类别List
    11         trainMat[i,:] = img2vector("digits\trainingDigits\%s" %fileNameStr)
    12 
    13     testFileList = os.listdir("digits\testDigits")  #处理测试文件 
    14     mtest = len(testFileList)
    15     errorcount = 0.0 
    16     for i in range(mtest):
    17         fileNameStr = testFileList[i]
    18         fileStr = fileNameStr.split(".")[0]
    19         classNumStr = fileStr.split("_")[0]
    20         vectorUderTest = img2vector("digits\testDigits\%s" %fileNameStr)
    21         classifierResult = classify0(vectorUderTest,trainMat,hwlables,3)
    22         if(classifierResult != classNumStr) :
    23             errorcount += 1.0
    24     print("errorcount:%d" %errorcount)
    25     print("error rate :%f" %float((errorcount/mtest)))

    final  主程序

    需要运行哪个程序就打开哪个程序的注释即可

    1 if __name__ == '__main__':
    2     #simpletest()   # 测试createDataSet 和classify0 这2个函数
    3     plot()   #画datingTestSet2.txt这个数据的图像
    4     #handwritingClassTest()  #手写识别程序
    5     #datingClassTest()  #约会 识别程序
    6     #classifyPerson()  #从命令行读入数据
    7    
    8    
    def simpletest():  # 测试createDataSet 和classify0 这2个函数
        intX = [0.0,0.0]
        k = 3
        dataX,labels = createDataSet() 
        a = classify0(intX,dataX,labels,k)
        print("dataX:" ,dataX)
        print("labels:",labels)
        print("predict:" ,a)




  • 相关阅读:
    LDAP 总结
    关于OpenLDAPAdmin管理页面提示“This base cannot be created with PLA“问题. Strong Authentication Required问题
    PHP 7.0 5.6 下安裝 phpLDAPadmin 发生错误的修正方法
    ldap、additional info: no global superior knowledge
    ldap安装配置过程中遇到的错误,以及解决方法
    转: LDAP有啥子用? 用户认证
    Mac 安装 brew
    go test 单元函数测试
    haproxy httpcheck with basic auth
    architecture and business process modelling
  • 原文地址:https://www.cnblogs.com/zle1992/p/6841493.html
Copyright © 2011-2022 走看看