#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/11/7 14:50
# @Author : gylhaut
# @Site : "http://www.cnblogs.com/gylhaut/"
# @File : KNNAlgorithm.py
# @Software: PyCharm
"""Minimal k-nearest-neighbours (KNN) classifier with a tiny demo.

The module exposes ``createDataSet`` (a toy training set) and ``classify``
(majority vote among the k nearest training samples, Euclidean distance).
The demo at the bottom classifies one point and prints the result.
"""
import operator
from collections import Counter

import numpy as np
# The wildcard import is kept only so that names this module has always
# re-exported stay importable. It shadows builtins such as ``sum``, so the
# code below deliberately uses the explicit ``np.`` prefix instead.
from numpy import *  # noqa: F401,F403


def createDataSet():
    """Return the toy training data and the class label of each row.

    :return: ``(group, labels)`` where ``group`` is a 4x2 float array (one
        sample per row) and ``labels`` is a list of 4 class names aligned
        with the rows of ``group``.
    """
    group = np.array([[1.0, 2.0], [1.2, 0.1], [0.1, 1.4], [0.3, 3.5]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def classify(input, dataSet, label, k):
    """Classify one sample by majority vote among its k nearest neighbours.

    :param input: the sample to classify (sequence or array of features)
    :param dataSet: training samples, one sample per row
    :param label: class label of each training sample, indexable by row
    :param k: number of nearest neighbours that take part in the vote
    :return: the label with the most votes; on a tie, the tied label that
        was encountered first when walking the neighbours from nearest to
        farthest
    :raises ValueError: if ``k`` is smaller than 1 or larger than the number
        of training samples
    """
    samples = np.asarray(dataSet, dtype=float)
    query = np.asarray(input, dtype=float)
    sample_count = samples.shape[0]

    # Without this check k <= 0 ended in an UnboundLocalError and a too-large
    # k in an IndexError from deep inside the voting loop.
    if not 1 <= k <= sample_count:
        raise ValueError(
            "k must be between 1 and the number of training samples "
            "({0}), got {1!r}".format(sample_count, k)
        )

    # Euclidean distance from the query to every training sample; NumPy
    # broadcasting subtracts the query from each row.
    diff = samples - query
    dist = np.sqrt((diff ** 2).sum(axis=1))

    # Indices of the k closest samples, nearest first. A stable sort keeps
    # equidistant samples in their original row order.
    nearest = np.argsort(dist, kind="stable")[:k]

    # Count the labels of the k neighbours and return the most frequent one.
    votes = Counter(label[index] for index in nearest)
    return votes.most_common(1)[0][0]


# Demo. NOTE: this runs on import as well as when executed as a script, and
# the module-level names (including ``input``, which shadows the builtin)
# are kept as they were so existing behaviour is preserved.
dataSet, labels = createDataSet()
input = np.array([1.1, 0.3])
K = 3
output = classify(input, dataSet, labels, K)
print("测试数据为:", input, "分类结果为:", output)