zoukankan      html  css  js  c++  java
  • k近邻算法

    介绍

    k近邻算法(KNN)属于监督学习的分类算法,通过测量不同特征值之间的距离进行分类,算法过程如下

    • 计算数据点与已知数据集中每个点的距离
    • 对距离从小到大进行排序
    • 选取前k个距离值
    • 确定前k个距离值所在类别的出现的概率
    • 将前k个点出现频率最高的类别作为当前数据的预测分类

    主要代码如下

    def classfiy(inData, dataSet, labels, k):
        dataSize = dataSet.shape[0]  # 得到数组的行维度,即数据的个数
        # 先通过tile将输入的数据扩展为与dataSet相同维度的数组,再通过距离公式计算距离
        distance = (((tile(inData, (dataSize, 1)) - dataSet) ** 2).sum(axis=1)) ** 0.5
        sortIndex = distance.argsort()  # 返回数组值从小到大的索引值
        classCount = {}
        for i in range(k):  # 只对前k个计数
            headLabel = labels[sortIndex[i]]
            classCount[headLabel] = classCount.get(headLabel, 0) + 1  # 统计前k个中出现标签的次数
        # 对字典按照第二个值(即出现的次数)进行排序,用reverse指定从大到小排
        sortCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortCount[0][0]  # 返回第一个的标签
    

    其中距离计算,通过公式,如((x_{1},y_{1})(x_{2},y_{2}))两点的距离d为(d=sqrt{(x_{1}-x_{2})^2+(y_{1}-y_{2})^2})

    用KNN识别数字图片中的数字

    只是个玩具程序

    收集数据

    每个数字准备了10张图片,分别存在digit中的以各个数字命名的文件夹下

    又为每个数据准备了5张图片,以同样的规则存在digit2的各个文件夹下

    准备数据

    缩放图像
    采用了pillow中的resize函数,同一将图像缩放为50*50
    newImg = img.resize((50, 50))
    二值化图像
    开始想直接通过convet('1')直接将图像二值化,但出现了很多噪音
    所以通过以下程序将图像二值化。其中230为设定的阀值,多次尝试,发现230效果较好

        for i in range(rows):
            for j in range(cols):
                if (imgArray[i, j] <= 230):
                    imgArray[i, j] = 0
                else:
                    imgArray[i, j] = 255
    

    转化为一维向量
    将读取的处理后的图片的像素值转化为一维向量

    测试

    通过读取测试集中的数据,进行预测,和实际的类别比对,查看正确率

    程序

    from PIL import Image
    from numpy import *
    import os
    import operator
    
    #缩放为相同大小
    def toSame(img):
        newImg = img.resize((50, 50))
        return newImg
    
    #二值化处理
    def toBinarry(img):
        imgArray = array(img)
        rows, cols = imgArray.shape
        for i in range(rows):
            for j in range(cols):
                if (imgArray[i, j] <= 230):
                    imgArray[i, j] = 0
                else:
                    imgArray[i, j] = 255
        return imgArray
    
    #读取每个文件夹下的每张图片
    def readImage(filePath):
        dataList = []
        labels = []
        for i in range(10):
            imagePath = filePath + '/' + str(i)
            files = os.listdir(imagePath)
            for j in files:
                labels.append(j.split('_')[0])#因为每张图片采用‘数字_第几张的命名方式’,所以通过下横线分割,取得前面的作为图片的分类标签
                img = Image.open(imagePath + '/' + j).convert('L')#先灰度化处理
                imgArray = toBinarry(toSame(img))
                dataList.append(imgArray.ravel())#转变为一维后加入列表
        dataSet = array(dataList)
        return dataSet, labels
    
    #分类算法
    def classfiy(inData, dataSet, labels, k):
        dataSize = dataSet.shape[0]  # 得到数组的行维度,即数据的个数
        # 先通过tile将输入的数据扩展为与dataSet相同维度的数组,再通过距离公式计算距离
        distance = (((tile(inData, (dataSize, 1)) - dataSet) ** 2).sum(axis=1)) ** 0.5
        sortIndex = distance.argsort()  # 返回数组值从小到大的索引值
        classCount = {}
        for i in range(k):  # 只对前k个计数
            headLabel = labels[sortIndex[i]]
            classCount[headLabel] = classCount.get(headLabel, 0) + 1  # 统计前k个中出现标签的次数
        # 对字典按照第二个值(即出现的次数)进行排序,用reverse指定从大到小排
        sortCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortCount[0][0]  # 返回第一个的标签
    
    # 进行测试
    dataSet, labels = readImage('./digit')
    dataSet2, labels2 = readImage('./digit2')
    n = 0
    for i in range(len(dataSet2)):
        predict = classfiy(dataSet2[i], dataSet, labels, 10)
        print(predict + ' ' + labels2[i])
        if (predict == labels2[i]):
            n = n + 1
    # 查看准确率
    print(n / len(dataSet2))
    

    运行结果


    发现准确率只有0.62

    总结

    • 准确率如此低,可能是数据不足,也可能对图像处理不好。在二值化时,效果其实并不完美。也可能需要对图像进行一些裁剪。在二值化时,本程序也只适合一些浅色底子的数字图片
    • 采用不同的k,预测的效果也是不同,也需要找到一个最佳的k

    其它

    • 在处理数据时,通常用到的归一化
    def toNormal(dataSet):
        # 归一化
        min = dataSet.min(0)
        max = dataSet.max(0)
        # 公式normal=(x-min)/(max-min)
        normalArray = (dataSet - tile(min, (dataSet.shape[0], 1))) / tile(max - min, (dataSet.shape[0], 1))
        return normalArray
    
    
    def toClear(imgArray):
        rows, cols = imgArray.shape
        for y in range(1, cols - 1):
            for x in range(1, rows - 1):
                count = 0
                if imgArray[x, y - 1] == 255:  # 上
                    count = count + 1
                if imgArray[x, y + 1] == 255:  # 下
                    count = count + 1
                if imgArray[x - 1, y] == 255:  # 左
                    count = count + 1
                if imgArray[x + 1, y] == 255:  # 右
                    count = count + 1
                if imgArray[x - 1, y - 1] == 255:  # 左上
                    count = count + 1
                if imgArray[x - 1, y + 1] == 255:  # 左下
                    count = count + 1
                if imgArray[x + 1, y - 1] == 255:  # 右上
                    count = count + 1
                if imgArray[x + 1, y + 1] == 255:  # 右下
                    count = count + 1
                if count > 4:
                    imgArray[x, y] = 255
        return imgArray
    
  • 相关阅读:
    CentOS 7.X 关闭SELinux
    删除或重命名文件夹和文件的方法
    centos7-每天定时备份 mysql数据库
    centos7 tar.gz zip 解压命令
    MySQL5.6/5.7/8.0版本授权用户远程连接
    下载CentOS7系统
    使用js实现tab页签切换效果
    sql优化常用的几种方法
    mysql 多表联查的快速查询(索引)
    【图论】强连通分量+tarjan算法
  • 原文地址:https://www.cnblogs.com/Qi-Lin/p/12247163.html
Copyright © 2011-2022 走看看