zoukankan      html  css  js  c++  java
  • 编写knn算法实现手写体识别

    • 一、首先学习学习knn算法。

    kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。

                         

    画个简单的图,假设其他类型的图案是有类别的,我们需要将中间的六边形进行归类,这是我们可以利用knn,计算它与其他图形的距离,取k值,决策它应该归类到哪一类中。

    看右上图,绿色圆要被决定赋予哪个类,是红色三角形还是蓝色四方形?如果K=3,由于红色三角形所占比例为2/3,绿色圆将被赋予红色三角形那个类,如果K=5,由于蓝色四方形比例为3/5,因此绿色圆被赋予蓝色四方形类。

    • 二、接下来代码实现knn算法:
    def knn(k,testdata,traindata,labels):
        traindatasize=traindata.shape[0]
        dif=tile(testdata,(traindatasize,1))-traindata
        sqdif=dif**2
        sumsqdif=sqdif.sum(axis=1)
        distance=sumsqdif**0.5
        sortdistance=distance.argsort()
        count={}
        for i in range(0,k):
            vote=labels[sortdistance[i]]
            count[vote]=count.get(vote,0)+1
        sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)
        return sortcount[0][0]

    knn算法步骤:

    1、处理数据

    2、数据向量化

    3、计算欧几里得距离

    4、根据距离进行分类

    参数   k用于改变误差率,testdata:测试数据集,traindata:训练数据集,labels:标签

    • 三、了解手写体识别

    我们通过画图,或者在纸上写上数字或字母,将照片进行处理,得到固定的照片规格,将照片转换为文本0,1表示的内容。例如下图:(图中内容68)训练集和测试集可以自己手写利用PIL库转化,也可以互联网上找。

    为了简单起见,固定图片的像素为32*32。

    为了保证结果的低误差率,可以将训练集数量设置多一点。

    • 四、将图片转换为0,1文本

    pip install pillow   安装对应库,PIL.Image处理图片,getpixel方法获取像素,判断像素的颜色,进行文本内容0,1的写入。(下面我对应图片没有设置像素,导致写入内容很多,在此只做简单思路分析)

    • 五、加载数据,将训练集(测试集)转化为数组
    def datatoarray(fname):
        arr=[]
        fh=open(fname)
        for i in range(0,32):
            thisline=fh.readline()
            for j in range(0,32):
                arr.append(int(thisline[j]))
        return arr
    • 六、建立一个函数取出对应手写体的名字(输入的参数是文件目录),从而建立label
    def seplabel(fname):
        filestr=fname.split(".")[0]
        labels=int(filestr.split("_")[0])
        return labels
    • 七、建立训练数据集
    def traindata():
        labels=[]
        trainfile=os.listdir("./traindata")
        num=len(trainfile)
        #像素32*32=1024
        #创建一个数组存放训练数据,行为文件总数,列为1024,为一个手写体的内容 zeros创建规定大小的数组
        trainarr=zeros((num,1024))
        for i in range(0,num):
            thisfname=trainfile[i]
            thislabel=seplabel(thisfname)
            labels.append(thislabel)
            trainarr[i]=datatoarray("./traindata/"+thisfname)
        return trainarr,labels
    • 八、用测试数据调用knn算法完成测试
    def datatest():
        trainarr,labels=traindata()
        testlist=os.listdir("./testdata")
        tnum=len(testlist)
        for i in range(tnum):
            thisname=testlist[i]
            testarr=datatoarray("./testdata/"+thisname)
            rknn=knn(k=3,testdata=testarr,traindata=trainarr,labels=labels)  
            print(str(thisname)+"  :  "+str(rknn))

    运行效果(冒号前是测试集的文件名,对应数字的第几个测试样本):可以看到基本上都能准确的将测试集中的手写体数字识别正确并归类,有少部分数字识别失败,将测试样本归类为其他内容。可以更改K值,来改变误差率。

  • 相关阅读:
    uva 10369 Arctic Network
    uvalive 5834 Genghis Khan The Conqueror
    uvalive 4848 Tour Belt
    uvalive 4960 Sensor Network
    codeforces 798c Mike And Gcd Problem
    codeforces 796c Bank Hacking
    codeforces 768c Jon Snow And His Favourite Number
    hdu 1114 Piggy-Bank
    poj 1276 Cash Machine
    bzoj 2423 最长公共子序列
  • 原文地址:https://www.cnblogs.com/hecxx/p/11959851.html
Copyright © 2011-2022 走看看