zoukankan      html  css  js  c++  java
  • k近邻算法python实现 -- 《机器学习实战》

    
    
      1 '''
      2 Created on Nov 06, 2017
      3 kNN: k Nearest Neighbors
      4 
      5 Input:      inX: vector to compare to existing dataset (1xN)
      6             dataSet: size m data set of known vectors, one vector per row (MxN)
      7             labels: data set labels (1xM vector)
      8             k: number of neighbors to use for comparison (should be an odd number)
      9 
     10 Output:     the most popular class label
     11 
     12 @author: Liu Chuanfeng
     13 '''
     14 import operator
     15 import numpy as np
     16 import matplotlib.pyplot as plt
     17 from os import listdir
     18 
     19 def classify0(inX, dataSet, labels, k):
     20     dataSetSize = dataSet.shape[0]
     21     diffMat = np.tile(inX, (dataSetSize,1)) - dataSet
     22     sqDiffMat = diffMat ** 2
     23     sqDistances = sqDiffMat.sum(axis=1)
     24     distances = sqDistances ** 0.5
     25     sortedDistIndicies = distances.argsort()
     26     classCount = {}
     27     for i in range(k):
     28         voteIlabel = labels[sortedDistIndicies[i]]
     29         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
     30     sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
     31     return sortedClassCount[0][0]
     32 
     33 #数据预处理,将文件中数据转换为矩阵类型
     34 def file2matrix(filename):
     35     fr = open(filename)
     36     arrayLines = fr.readlines()
     37     numberOfLines = len(arrayLines)
     38     returnMat = np.zeros((numberOfLines, 3))
     39     classLabelVector = []
     40     index = 0
     41     for line in arrayLines:
     42         line = line.strip()
     43         listFromLine = line.split('	')
     44         returnMat[index,:] = listFromLine[0:3]
     45         classLabelVector.append(int(listFromLine[-1]))
     46         index += 1
     47     return returnMat, classLabelVector
     48 
     49 #数据归一化处理:由于矩阵各列数据取值范围的巨大差异导致各列对计算结果的影响大小不一,需要归一化以保证相同的影响权重
     50 def autoNorm(dataSet):
     51     maxVals = dataSet.max(0)
     52     minVals = dataSet.min(0)
     53     ranges = maxVals -  minVals
     54     m = dataSet.shape[0]
     55     normDataSet = (dataSet - np.tile(minVals, (m, 1))) / np.tile(ranges, (m, 1))
     56     return normDataSet, ranges, minVals
     57 
     58 #约会网站测试代码
     59 def datingClassTest():
     60     hoRatio = 0.10
     61     datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
     62     normMat, ranges, minVals = autoNorm(datingDataMat)
     63     m = normMat.shape[0]
     64     numTestVecs = int(m * hoRatio)
     65     errorCount = 0.0
     66     for i in range(numTestVecs):
     67         classifyResult = classify0(normMat[i,:], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
     68         print('theclassifier came back with: %d, the real answer is: %d' % (classifyResult, datingLabels[i]))
     69         if ( classifyResult != datingLabels[i]):
     70             errorCount += 1.0
     71         print ('the total error rate is: %.1f%%' % (errorCount/float(numTestVecs) * 100))
     72 
     73 #约会网站预测函数
     74 def classifyPerson():
     75     resultList = ['not at all', 'in small doses', 'in large doses']
     76     percentTats = float(input("percentage of time spent playing video games?"))
     77     ffMiles = float(input("frequent flier miles earned per year?"))
     78     iceCream = float(input("liters of ice cream consumed per year?"))
     79     datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
     80     normMat, ranges, minVals = autoNorm(datingDataMat)
     81     inArr = np.array([ffMiles, percentTats, iceCream])
     82     classifyResult = classify0((inArr-minVals)/ranges, normMat, datingLabels, 3)
     83     print ("You will probably like this persoon:", resultList[classifyResult - 1])
     84 
     85 
     86 #手写识别系统#============================================================================================================
     87 #数据预处理:输入图片为32*32的文本类型,将其形状转换为1*1024
     88 def img2vector(filename):
     89     returnVect = np.zeros((1, 1024))
     90     fr = open(filename)
     91     for i in range(32):
     92         lineStr = fr.readline()
     93         for j in range(32):
     94             returnVect[0, 32*i+j] = int(lineStr[j])
     95     return returnVect
     96 
     97 #手写数字识别系统测试代码
     98 def handwritingClassTest():
     99     hwLabels = []
    100     trainingFileList = listdir('C:\Private\PycharmProjects\Algorithm\kNNdigits\traingDigits')
    101     m = len(trainingFileList)
    102     trainingMat = np.zeros((m, 1024))
    103     for i in range(m):                                     #|
    104         fileNameStr = trainingFileList[i]                  #|
    105         fileName = fileNameStr.split('.')[0]               #| 获取训练集路径下每一个文件,分割文件名,将第一个数字作为标签存储在hwLabels中
    106         classNumber = int(fileName.split('_')[0])          #|
    107         hwLabels.append(classNumber)                       #|
    108         trainingMat[i,:] = img2vector('C:\Private\PycharmProjects\Algorithm\kNNdigits\traingDigits\%s' % fileNameStr)    #变换矩阵形状: from 32*32 to 1*1024
    109     testFileList = listdir('C:\Private\PycharmProjects\Algorithm\kNNdigits\testDigits')
    110     errorCount = 0.0
    111     mTest = len(testFileList)
    112     for i in range(mTest):              #同训练集
    113         fileNameStr = testFileList[i]
    114         fileName = fileNameStr.split('.')[0]
    115         classNumber = int(fileName.split('_')[0])
    116         vectorUnderTest = img2vector('C:\Private\PycharmProjects\Algorithm\kNNdigits\testDigits\%s' % fileNameStr)
    117         classifyResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)   #计算欧氏距离并分类,返回计算结果
    118         print ('The classifier came back with: %d, the real answer is: %d' % (classifyResult, classNumber))
    119         if (classifyResult != classNumber):
    120             errorCount += 1.0
    121     print ('The total number of errors is: %d' % (errorCount))
    122     print ('The total error rate is: %.1f%%' % (errorCount/float(mTest) * 100))
    123 
    124 # Simple unit test of func: file2matrix()
    125 #datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    126 #print (datingDataMat)
    127 #print (datingLabels)
    128 
    129 # Usage of figure construction of matplotlib
    130 #fig=plt.figure()
    131 #ax = fig.add_subplot(111)
    132 #ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*np.array(datingLabels), 15.0*np.array(datingLabels))
    133 #plt.show()
    134 
    135 # Simple unit test of func: autoNorm()
    136 #normMat, ranges, minVals = autoNorm(datingDataMat)
    137 #print (normMat)
    138 #print (ranges)
    139 #print (minVals)
    140 
    141 # Simple unit test of func: img2vector
    142 #testVect = img2vector('C:\Private\PycharmProjects\Algorithm\kNNdigits\testDigits\0_13.txt')
    143 #print (testVect[0, 32:63] )
    144 
    145 #约会网站测试
    146 datingClassTest()
    147 
    148 #约会网站预测
    149 classifyPerson()
    150 
    151 #手写数字识别系统预测
    152 handwritingClassTest()
    
    
    Output:

    theclassifier came back with: 3, the real answer is: 3
    the total error rate is: 0.0%
    theclassifier came back with: 2, the real answer is: 2
    the total error rate is: 0.0%
    theclassifier came back with: 1, the real answer is: 1
    the total error rate is: 0.0%

    ...

    theclassifier came back with: 2, the real answer is: 2
    the total error rate is: 4.0%
    theclassifier came back with: 1, the real answer is: 1
    the total error rate is: 4.0%
    theclassifier came back with: 3, the real answer is: 1
    the total error rate is: 5.0%

    percentage of time spent playing video games?10
    frequent flier miles earned per year?10000
    liters of ice cream consumed per year?0.5
    You will probably like this persoon: in small doses

    ...

    The classifier came back with: 9, the real answer is: 9
    The total number of errors is: 27
    The total error rate is: 6.8%

     Reference:

    《机器学习实战》

  • 相关阅读:
    RN8209校正软件开发心得(1)
    Chrome 31版本导出Excel问题
    ComBox选择
    网页设计的一般步骤
    .NET一套开发工具
    关于用sql语句实现一串数字位数不足在左侧补0的技巧
    python jieba模块详解
    python内置函数详细描述与实例演示
    Markdown的基本语法记录
    python configparser模块详解
  • 原文地址:https://www.cnblogs.com/knownx/p/7806231.html
Copyright © 2011-2022 走看看