zoukankan      html  css  js  c++  java
  • 机器学习算法--KNN算法

    KNN算法原理

    KNN(K-Nearest Neighbor)最邻近分类算法是数据挖掘分类(classification)技术中最简单的算法之一,其指导思想是”近朱者赤,近墨者黑“,即由你的邻居来推断出你的类别。

     KNN最邻近分类算法的实现原理:为了判断未知样本的类别,以所有已知类别的样本作为参照,计算未知样本与所有已知样本的距离,从中选取与未知样本距离最近的K个已知样本,根据少数服从多数的投票法则(majority-voting),将未知样本与K个最邻近样本中所属类别占比较多的归为一类。

    Python实现KNN算法

       

     1 import numpy as np
     2 import operator
     3 
     4 def createDataset():
     5   #四组二维特征
     6   group = np.array([[5,115],[7,106],[56,11],[66,9]])
     7   #四组对应标签
     8   labels = ('动作片','动作片','爱情片','爱情片')
     9   return group,labels
    10 
    11 def classify(intX,dataSet,labels,k):
    12   '''
    13   KNN算法
    14   '''
    15   #numpy中shape[0]返回数组的行数,shape[1]返回列数
    16   dataSetSize = dataSet.shape[0]
    17   #将intX在横向重复dataSetSize次,纵向重复1次
    18   #例如intX=([1,2])--->([[1,2],[1,2],[1,2],[1,2]])便于后面计算
    19   diffMat = np.tile(intX,(dataSetSize,1))-dataSet
    20   #二维特征相减后乘方
    21   sqdifMax = diffMat**2
    22   #计算距离
    23   seqDistances = sqdifMax.sum(axis=1)
    24   distances = seqDistances**0.5
    25   print ("distances:",distances)
    26   #返回distance中元素从小到大排序后的索引
    27   sortDistance = distances.argsort()
    28   print ("sortDistance:",sortDistance)
    29   classCount = {}
    30   for i in range(k):
    31   #取出前k个元素的类别
    32   voteLabel = labels[sortDistance[i]]
    33   classCount[voteLabel] = classCount.get(voteLabel,0)+1
    34   #dict.get(key,default=None),字典的get()方法,返回指定键的值,如果值不在字典中返回默认值。
    35   #reverse降序排序字典
    36 
    37   #classCount.iteritems()将classCount字典分解为元组列表,operator.itemgetter(1)按照第二个元素的次序对元组进行排序,reverse=True是逆序,即按照从大到小的顺序排列
    38   sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
    39   #结果sortedClassCount = [('动作片', 2), ('爱情片', 1)]
    40   print ("sortedClassCount:",sortedClassCount)
    41   print("===>>>%s",classCount.items())
    42   return sortedClassCount[0][0]
    43 if __name__ == '__main__':
    44   group,labels = createDataset()
    45   test = [20,101]
    46   test_class = classify(test,group,labels,3)
    47   print (test_class)
    View Code
  • 相关阅读:
    .Spring事务管理
    什么叫事务;什么叫业务逻辑;什么叫持久化
    Hibernate基本应用01
    Maven整理
    责任链模式和观察者模式
    SpringBoot基础入门
    反射总结
    多线程
    IO流
    File类总结
  • 原文地址:https://www.cnblogs.com/yuyang81577/p/11359799.html
Copyright © 2011-2022 走看看