zoukankan      html  css  js  c++  java
  • KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro
    主页:https://codeshellme.github.io

    上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字

    1,手写数字数据集

    手写数字数据集是一个用于图像处理的数据集,这些数据描绘了 [0, 9] 的数字,我们可以用KNN 算法来识别这些数字。

    MNIST 是完整的手写数字数据集,其中包含了60000 个训练样本和10000 个测试样本。

    sklearn 中也有一个自带的手写数字数据集

    • 共包含 1797 个数据样本,每个样本描绘了一个 8*8 像素的 [0, 9] 的数字。
    • 每个样本由 65 个数字组成:
      • 前 64 个数字是特征数据,特征数据的范围是 [0, 16]
      • 最后一个数字是目标数据,目标数据的范围是 [0, 9]

    我们抽出 5 个样本来看下:

    0,0,5,13,9,1,0,0,0,0,13,15,10,15,5,0,0,3,15,2,0,11,8,0,0,4,12,0,0,8,8,0,0,5,8,0,0,9,8,0,0,4,11,0,1,12,7,0,0,2,14,5,10,12,0,0,0,0,6,13,10,0,0,0,0
    0,0,0,12,13,5,0,0,0,0,0,11,16,9,0,0,0,0,3,15,16,6,0,0,0,7,15,16,16,2,0,0,0,0,1,16,16,3,0,0,0,0,1,16,16,6,0,0,0,0,1,16,16,6,0,0,0,0,0,11,16,10,0,0,1
    0,0,0,4,15,12,0,0,0,0,3,16,15,14,0,0,0,0,8,13,8,16,0,0,0,0,1,6,15,11,0,0,0,1,8,13,15,1,0,0,0,9,16,16,5,0,0,0,0,3,13,16,16,11,5,0,0,0,0,3,11,16,9,0,2
    0,0,7,15,13,1,0,0,0,8,13,6,15,4,0,0,0,2,1,13,13,0,0,0,0,0,2,15,11,1,0,0,0,0,0,1,12,12,1,0,0,0,0,0,1,10,8,0,0,0,8,4,5,14,9,0,0,0,7,13,13,9,0,0,3
    0,0,0,1,11,0,0,0,0,0,0,7,8,0,0,0,0,0,1,13,6,2,2,0,0,0,7,15,0,9,8,0,0,5,16,10,0,16,6,0,0,4,15,16,13,16,1,0,0,0,0,3,15,10,0,0,0,0,0,2,16,4,0,0,4
    

    使用该数据集,需要先加载:

    >>> from sklearn.datasets import load_digits
    >>> digits = load_digits()
    

    查看第一个图像数据:

    >>> digits.images[0]
    array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.],
           [ 0.,  0., 13., 15., 10., 15.,  5.,  0.],
           [ 0.,  3., 15.,  2.,  0., 11.,  8.,  0.],
           [ 0.,  4., 12.,  0.,  0.,  8.,  8.,  0.],
           [ 0.,  5.,  8.,  0.,  0.,  9.,  8.,  0.],
           [ 0.,  4., 11.,  0.,  1., 12.,  7.,  0.],
           [ 0.,  2., 14.,  5., 10., 12.,  0.,  0.],
           [ 0.,  0.,  6., 13., 10.,  0.,  0.,  0.]])
    

    我们可以用 matplotlib 将该图像画出来:

    >>> import matplotlib.pyplot as plt
    >>> plt.imshow(digits.images[0])
    >>> plt.show()
    

    画出来的图像如下,代表 0

    在这里插入图片描述

    2,sklearn 对 KNN 算法的实现

    sklearn 库的 neighbors 模块实现了KNN 相关算法,其中:

    • KNeighborsClassifier 类用于分类问题
    • KNeighborsRegressor 类用于回归问题

    这两个类的构造方法基本一致,这里我们主要介绍 KNeighborsClassifier 类,原型如下:

    KNeighborsClassifier(
    	n_neighbors=5, 
    	weights='uniform', 
    	algorithm='auto', 
    	leaf_size=30, 
    	p=2, 
    	metric='minkowski', 
    	metric_params=None, 
    	n_jobs=None, 
    	**kwargs)
    

    来看下几个重要参数的含义:

    • n_neighbors:即 KNN 中的 K 值,一般使用默认值 5。
    • weights:用于确定邻居的权重,有三种方式:
      • weights=uniform,表示所有邻居的权重相同。
      • weights=distance,表示权重是距离的倒数,即与距离成反比。
      • 自定义函数,可以自定义不同距离所对应的权重,一般不需要自己定义函数。
    • algorithm:用于设置计算邻居的算法,它有四种方式:
      • algorithm=auto,根据数据的情况自动选择适合的算法。
      • algorithm=kd_tree,使用 KD 树 算法。
        • KD 树是一种多维空间的数据结构,方便对数据进行检索。
        • KD 树适用于维度较少的情况,一般维数不超过 20,如果维数大于 20 之后,效率会下降。
      • algorithm=ball_tree,使用球树算法。
        • KD 树一样都是多维空间的数据结构。
        • 球树更适用于维度较大的情况。
      • algorithm=brute,称为暴力搜索
        • 它和 KD 树相比,采用的是线性扫描,而不是通过构造树结构进行快速检索。
        • 缺点是,当训练集较大的时候,效率很低。
      • leaf_size:表示构造 KD 树球树时的叶子节点数,默认是 30。
        • 调整 leaf_size 会影响树的构造和搜索速度。

    3,构造 KNN 分类器

    首先加载数据集:

    from sklearn.datasets import load_digits
    
    digits = load_digits()
    data = digits.data     # 特征集
    target = digits.target # 目标集
    

    将数据集拆分为训练集(75%)和测试集(25%):

    from sklearn.model_selection import train_test_split
    
    train_x, test_x, train_y, test_y = train_test_split(
        data, target, test_size=0.25, random_state=33)
    

    构造KNN 分类器:

    from sklearn.neighbors import KNeighborsClassifier
    
    # 采用默认参数
    knn = KNeighborsClassifier() 
    

    拟合模型:

    knn.fit(train_x, train_y) 
    

    预测数据:

    predict_y = knn.predict(test_x) 
    

    计算模型准确度:

    from sklearn.metrics import accuracy_score
    
    score = accuracy_score(test_y, predict_y)
    print score # 0.98
    

    最终计算出来模型的准确度是 98%,准确度还是不错的。

    4,总结

    本篇文章使用KNN 算法处理了一个实际的分类问题,主要介绍了以下几点:

    • 介绍了sklearn 中自带的手写数字集,并用 matplotlib 模块画出了数字图像。
    • 介绍了sklearnneighbors.KNeighborsClassifier 类的用法。
    • 使用 KNeighborsClassifier 来识别手写数字。

    (本节完。)


    推荐阅读:

    KNN 算法-理论篇-如何给电影进行分类

    决策树算法-理论篇-如何计算信息纯度

    决策树算法-实战篇-鸢尾花及波士顿房价预测

    朴素贝叶斯分类-理论篇-如何通过概率解决分类问题

    朴素贝叶斯分类-实战篇-如何进行文本分类


    欢迎关注作者公众号,获取更多技术干货。

    码农充电站pro

  • 相关阅读:
    机器学习——ALS算法
    机器学习——Kmeans算法
    机器学习——欧式距离和余弦距离
    《JAVA策略模式》
    POSTGRESQL 数据库导入导出
    SpringBoot解决前后端全局跨域问题WebMvcConfigurer
    java读取json文件进行解析,String转json对象
    docker: Error response from daemon: Conflict. The container name "/mysql8.0" is already
    学习笔记:shell 中 [-eq] [-ne] [-gt] [-lt] [ge] [le]
    linux 判断一个用户是否存在
  • 原文地址:https://www.cnblogs.com/codeshell/p/14077625.html
Copyright © 2011-2022 走看看