zoukankan      html  css  js  c++  java
  • 基于python 实现KNN 算法

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    # @Time    : 2018/11/7 14:50
    # @Author  : gylhaut
    # @Site    : "http://www.cnblogs.com/gylhaut/"
    # @File    : KNNAlgorithm.py
    # @Software: PyCharm
    
    # coding:utf-8
    
    from numpy import *
    import operator
    
    
    ##给出训练数据以及对应的类别
    def createDataSet():
        group = array([[1.0, 2.0], [1.2, 0.1], [0.1, 1.4], [0.3, 3.5]])
        labels = ['A', 'A', 'B', 'B']
        return group, labels
    
    
    ###通过KNN进行分类
    def classify(input, dataSet, label, k):
        '''
    
        :param input: test集
        :param dataSet: 训练集
        :param label: 训练output
        :param k: k值选择
        :return:
        '''
        dataSize = dataSet.shape[0] # 4
        ####计算欧式距离
        # print(tile(input, (dataSize, 1)))
        diff = tile(input, (dataSize, 1)) - dataSet
    
        sqdiff = diff ** 2
        squareDist = sum(sqdiff, axis=1)  ###行向量分别相加,从而得到新的一个行向量
        dist = squareDist ** 0.5
        #print(dist)
        ##对距离进行排序
        sortedDistIndex = argsort(dist)  ##argsort()根据元素的值从小到大对元素进行排序,返回下标
        #print(sortedDistIndex)
        classCount = {}
        for i in range(k):
            voteLabel = label[sortedDistIndex[i]]
            #print(voteLabel)
            ###对选取的K个样本所属的类别个数进行统计
            classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
        ###选取出现的类别次数最多的类别
        #print(classCount)
        maxCount = 0
        for key, value in classCount.items():
            if value > maxCount:
                maxCount = value
                classes = key
    
        return classes
    
    from numpy import *
    dataSet,labels = createDataSet()
    input = array([1.1,0.3])
    K = 3
    output = classify(input,dataSet,labels,K)
    print("测试数据为:",input,"分类结果为:",output)
  • 相关阅读:
    android系统属性获取及设置
    Android Strings.xml To CSV / Excel互转
    android adb命令 抓取系统各种 log
    Android开源日志库Logger的使用
    解决git仓库从http转为ssh所要处理的问题
    PHP中var_dump
    oracle文字与格式字符串不匹配的解决
    Apache服务器和tomcat服务器有什么区别?
    【手把手教你Maven】构建过程
    Spring MVC页面重定向
  • 原文地址:https://www.cnblogs.com/gylhaut/p/9922994.html
Copyright © 2011-2022 走看看