zoukankan      html  css  js  c++  java
  • KNN算法

    KNN 算法的 Python 实现

    from operator import itemgetter
    
    import numpy as np
    
    
    def createDataSet():
        """Return a tiny toy training set together with its class labels.

        Returns:
            (groups, labels): ``groups`` is a 4x2 float array of sample
            points; ``labels`` holds the matching class names -- 'A' for the
            two points at (1, 1.1) and (1, 1), 'B' for the two points at
            (0, 0) and (0, 0.1).
        """
        samplePoints = [[1, 1.1], [1, 1], [0, 0], [0, 0.1]]
        sampleLabels = ['A', 'A', 'B', 'B']
        return np.array(samplePoints), np.array(sampleLabels)
    
    def classify(inx, dataSet, labels, k):
        """Classify ``inx`` by majority vote among its ``k`` nearest training samples.

        Distance is Euclidean.

        Args:
            inx: the sample to classify; one value per feature column of
                ``dataSet``.
            dataSet: 2-D array of training samples, one row per sample.
            labels: sequence of class labels; ``labels[i]`` belongs to
                ``dataSet[i]``.
            k: number of nearest neighbours that vote. A value larger than the
                number of training samples is clamped to that number.

        Returns:
            The label with the most votes. Ties go to the label whose first
            voter is nearest to ``inx``.

        Raises:
            ValueError: if ``k`` < 1 or ``dataSet`` has no rows.
        """
        dataSet = np.asarray(dataSet)
        dataSetSize = dataSet.shape[0]
        if k < 1:
            raise ValueError('k must be at least 1, got %r' % (k,))
        if dataSetSize == 0:
            raise ValueError('dataSet must contain at least one sample')
        # Broadcasting subtracts inx from every row (equivalent to tiling it).
        diffMat = np.asarray(inx) - dataSet
        distance = (diffMat ** 2).sum(axis=1) ** 0.5
        # Stable sort: equidistant samples keep their dataset order, so the
        # vote is deterministic.
        sortedDistIndics = distance.argsort(kind='stable')
        classCount = {}
        # Slicing clamps k to the number of available samples.
        for idx in sortedDistIndics[:k]:
            voteLabel = labels[idx]
            classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
        # max() returns the first maximal item, and dicts keep insertion order,
        # so a tie goes to the label that was seen first (i.e. the nearest).
        return max(classCount.items(), key=itemgetter(1))[0]
    
    
    # Example (disabled): build the query vector from command-line arguments with a list comprehension
    # inxo = [float(s) if re.match(re.compile(r'\d+'), s) else s for s in sys.argv if not s.endswith('.py')]
    # inx = np.array(inxo)
    # print()
    # dataSet, labels = createDataSet()
    # ftypes = classify(inx, dataSet, labels, 3)
    # print(ftypes)
    
    def filematrix(filename):
        """Parse a tab-separated text file into a feature matrix and a label list.

        On each non-blank line, the first three tab-separated fields are
        numeric features and the last field is an integer class label.
        Blank lines are skipped.

        Args:
            filename: path of the text file to read.

        Returns:
            (returnMat, classLabelVector): an ``(n, 3)`` float array of
            features and a list of ``n`` int labels, where ``n`` is the number
            of non-blank lines.
        """
        rows = []
        classLabelVector = []
        # ``with`` guarantees the file handle is closed.
        with open(filename) as fr:
            for line in fr:
                line = line.strip()
                # Skip blank lines (e.g. a trailing empty line), which would
                # otherwise fail the numeric row assignment below.
                if not line:
                    continue
                listFromLine = line.split('\t')
                rows.append(listFromLine[0:3])
                classLabelVector.append(int(listFromLine[-1]))
        returnMat = np.zeros((len(rows), 3))
        for index, fields in enumerate(rows):
            # NumPy converts the string fields to float on assignment.
            returnMat[index, :] = fields
        return returnMat, classLabelVector
    
    
    def autoNorm(dataSet):
        """Min-max scale every column of ``dataSet`` into [0, 1].

        Each value becomes ``(value - column_min) / (column_max - column_min)``.
        A column whose values are all equal (range 0) is mapped to 0 instead
        of NaN.

        Args:
            dataSet: 2-D numeric array, one row per sample.

        Returns:
            (normDataSet, rangs, minVals): the scaled array, the per-column
            ``max - min`` (returned unmodified, so it may contain zeros), and
            the per-column minimum.
        """
        # Per-column minimum and maximum.
        minVals = dataSet.min(0)
        maxVals = dataSet.max(0)
        rangs = maxVals - minVals
        # Divide constant columns by 1 rather than 0 so they become zeros, not NaN.
        safeRangs = np.where(rangs == 0, 1, rangs)
        # Broadcasting applies the per-column min / range to every row.
        normDataSet = (dataSet - minVals) / safeRangs
        return normDataSet, rangs, minVals
    
    
    def datingClassTest(filename='datingTestSet.txt', hoRatio=0.10, k=3):
        """Estimate the k-NN classifier's error rate with a simple hold-out split.

        The data is read with ``filematrix`` and normalised with ``autoNorm``.
        The first ``int(m * hoRatio)`` rows are classified against the
        remaining rows, which serve as the training set. Each prediction and
        the overall error rate are printed.

        Args:
            filename: data file passed to ``filematrix``.
            hoRatio: fraction of rows held out for testing.
            k: number of neighbours passed to ``classify``.

        Returns:
            The error rate, i.e. misclassified test rows / test rows.

        Raises:
            ValueError: if the split leaves no test rows or no training rows.
        """
        datingDataMat, datingLabels = filematrix(filename)
        normMat, _, _ = autoNorm(datingDataMat)
        m = normMat.shape[0]
        numTestVecs = int(m * hoRatio)
        # Guard the division below and the training slice.
        if numTestVecs <= 0 or numTestVecs >= m:
            raise ValueError('hold-out split of %d rows with hoRatio=%r leaves no test or no training data'
                             % (m, hoRatio))
        # The training portion does not change between iterations; slice it once.
        trainMat = normMat[numTestVecs:m, :]
        trainLabels = datingLabels[numTestVecs:m]
        errorCount = 0
        for i in range(numTestVecs):
            classifierResult = classify(normMat[i, :], trainMat, trainLabels, k)
            print('the classifier came back with: %d, the real answer is:%d' % (classifierResult, datingLabels[i]))
            if classifierResult != datingLabels[i]:
                errorCount += 1
        errorRate = errorCount / float(numTestVecs)
        print('the total error rate is:%f' % errorRate)
        return errorRate
    
    
    # Script entry point: run the hold-out evaluation.
    # NOTE(review): this runs at module import time (no `if __name__ == "__main__"` guard),
    # so importing this module reads datingTestSet.txt from the current working directory.
    datingClassTest()
    # Optional scatter plot of the dating data (disabled).
    # NOTE(review): datingDataMat / datingLabels are locals of datingClassTest and are not
    # assigned at module level in the visible code; load them first (e.g. via filematrix)
    # before uncommenting this snippet.
    # import matplotlib.pyplot as plt
    # fig = plt.figure()
    # ax = fig.add_subplot(111)
    # ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2])
    # ax.scatter(datingDataMat[:, 0], datingDataMat[:, 2], 15.0 * np.array(datingLabels), 15.0 * np.array(datingLabels))
    # plt.show()
  • 相关阅读:
    JSP自定义标签_用简单标签控制标签体执行10次
    JSP自定义标签_用简单标签实现控制标签体是否执行
    eclipse 使用lombok 精简java bean
    转 :关于springmvc使用拦截器
    转: spring静态注入
    spring 4.0+quartz2.2 实现持久化
    排除maven jar冲突 maven tomcat插件启动报错 filter转换异常
    转 Quartz将Job持久化所需表的说明
    转 maven jetty 插件
    ORA-14300: 分区关键字映射到超出允许的最大分区数的分区
  • 原文地址:https://www.cnblogs.com/09120912zhang/p/7989773.html
Copyright © 2011-2022 走看看