zoukankan      html  css  js  c++  java
  • k近邻算法的Python实现

    k近邻算法的Python实现

    0. 写在前面

    这篇小教程适合对Python与NumPy有一定了解的朋友阅读,如果在阅读本文的源代码时感到吃力,请及时参照相关的教程或者文档。

    1. 算法原理

    k近邻算法(k Nearest Neighbor)可以简称为kNN。kNN是一个简单直观的算法,也是机器学习从业者入门首选的算法。先看一个简单的应用场景。

    小例子

    设有下表,命名为为表1

    电影名称 打斗镜头数量 接吻镜头数量 电影类型
    foo1 3 104 爱情片
    foo2 2 100 爱情片
    foo3 1 81 爱情片
    foo4 101 10 动作片
    foo5 99 5 动作片
    foo6 98 2 动作片

    一个朴素的愿望是,能够根据打斗镜头与接吻镜头的数量来推测一部电影是属于爱情片还是动作片。具体而言,如果有一部电影的相关信息如下,命名为表2:

    电影名称 打斗镜头数量 接吻镜头数量
    foo7 18 90

    我们能否给出这部电影的类型?

    解决方案

    表1可以抽象为一个矩阵A与一个列向量x如下:

    矩阵A

    foo1	3	104
    foo2	2	100
    foo3	1	81
    foo4	101	10
    foo5	99	5
    foo6	98	2
    

    列向量x

    爱情片
    爱情片
    爱情片
    动作片
    动作片
    动作片
    

    表2可以抽象为一个行向量a如下:

    行向量a

    foo7 18 90
    

    显然,可以求矩阵A中每一个行向量与行向量a欧式距离(本例中计算欧式距离时只考虑打斗镜头数量与接吻镜头数量两个分量),并按照距离由小到大排序,结果如下表,命名为表3:

    电影名称 与未知电影之的距离
    foo2 18.7
    foo3 19.2
    foo1 20.5
    foo4 115.3
    foo5 117.4
    foo6 118.9

    此时选择前k个距离最小的电影及其所属的类型,结果如下表,命名为表4:

    电影名称 类型
    foo2 爱情片
    foo3 爱情片
    foo4 爱情片

    找出表4中出现次数最多的类型——“爱情片”,即kNN认为行向量a所属的类型为爱情片。

    2. Python实现

    代码的核心部分是如下函数,将其保存在文件中mykNN.py中。

    import numpy as np
    import operator as op
    from collections import defaultdict
    
    def classify(vec, dataSet, labels, k):
        """
        要求dataSet为NumPy的array类型
        vec: 参照行向量a
        dataSet: 参照矩阵A
        labels: 参照列向量x
        k: kNN中选择前k小的行
        """
    
        size = dataSet.shape[0]
        assert size == len(labels) #断言,确保输入正确
        tmp = (dataSet - vec)**2 #使用了NumPy的广播机制
        tmp = tmp.sum(axis=1)
        tmp = tmp.argsort()
    
        tmpDict = defaultdict(int) #简化用于分组的代码
        for i in range(k):
            tmpDict[labels[tmp[i]]] += 1
    
        return max(tmpDict.items(),key=op.itemgetter(1))[0]
    
    

    3. 练手案例

    我们使用第2小节的代码解决第1小节的问题。下面的代码文件保存为test.py,请确保test.py与mykNN.py文件位于同一个路径下。

    import numpy as np
    import mykNN as knn
    
    if __name__ == "__main__":
        dataSet = np.array([
            [3, 104],
            [2, 100],
            [1, 81],
            [101, 10],
            [99, 5],
            [98, 2]
        ])
        labels = ["爱情片", "爱情片", "爱情片",
                  "动作片", "动作片", "动作片"]
        k = 3
        vec = [18, 90]
        res = knn.classify(vec,dataSet,labels,k)
        print(res)
    
    

    4. 补充说明

    真实的分类任务不会像我们的案例那样简单。

    • 一般来说,第3小节中的dataSet与labels都会放在文件或者数据库中,并且未必是NumPy可以处理的数据类型。这时需要增加读文件或者读数据库并解析转换数据的一系列代码。

    • 有时需要考虑对表格的不同字段归一化的问题。

    • 以数据驱动的应用的开发需要关注kNN算法的正确率,这时需要增加判断正确率或者错误率的代码。

  • 相关阅读:
    Tomcat性能优化总结
    shell 服务器监控 cpu 和 java 占用 CPU 脚本
    编写shell时,遇到let: not found错误及解决办法
    Studio 3T 破解 mogodb
    nginx/iptables动态IP黑白名单实现方案
    创业公司这两年
    致所有的开发者们
    如何成为一名全栈开发工程师
    谈谈在创业公司的几点感触
    推荐阅读《赢在下班后》
  • 原文地址:https://www.cnblogs.com/pkuimyy/p/11656135.html
Copyright © 2011-2022 走看看