zoukankan      html  css  js  c++  java
  • knn手写识别

    import numpy as np
    import operator
    import os
    
    #KNN算法
    def knn(k,testdata,traindata,labels):#(k,测试样本,训练集,分类)
        traindatasize=traindata.shape[0]#行数
        #测试样本和训练集样本数可能不一样,因此需要将测试集样本数扩展成和训练集一样多
        #从行方向扩展 tile(a,(size,1))
        dif=np.tile(testdata,(traindatasize,1))-traindata
        #计算距离
        sqdif=dif**2
        sumsqdif=sqdif.sum(axis=1)
        distance=sumsqdif**0.5
    
        sortdistance=distance.argsort()#从小到大排列,结果返回元素位置
        count={}
        for i in range(k):
            vote=labels[sortdistance[i]]
            #统计每一类列样本的数量
            count[vote]=count.get(vote,0)+1
        sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)
        #取包含样本数量最多的那一类别
        return sortcount[0][0]
    
    
    #加载数据,将文件转化为数组形式
    def datatoarray(filename):
        arr=[]
        fh=open(filename)
        for i in range(32):
            thisline=fh.readline()
            for j in range(32):
                arr.append(int(thisline[j]))
        return arr
    
    #获取文件的lable
    def get_labels(filename):
        label=int(filename.split('_')[0])
        return label
    
    #建立训练数据
    def train_data():
        labels=[]
        trainlist=os.listdir('traindata/')
        num=len(trainlist)
        #长度1024(列),每一行存储一个文件
        #用一个数组存储所有训练数据,行:文件总数,列:1024
        trainarr=np.zeros((num,1024))
        for i in range(num):
            thisfile=trainlist[i]
            labels.append(get_labels(thisfile))
            trainarr[i,:]=datatoarray("traindata/"+thisfile)
        return trainarr,labels
    
    #用测试数据调用KNN算法进行测试
    def datatest():
        a=[]#准确结果
        b=[]#预测结果
        traindata,labels=train_data()
        testlist=os.listdir('testdata/')
        fh=open('result_knn.csv','a')
        for test in testlist:
            testfile='testdata/'+test
            testdata=datatoarray(testfile)
            result=knn(3,testdata,traindata,labels)
            #将预测结果存在文本中
            fh.write(test+'-----------'+str(result)+'
    ')
            a.append(int(test.split('_')[0]))
            b.append(int(result))
        fh.close()
        return a,b
    
    if __name__=='__main__':
        a,b=datatest()
        num=0
        for i in range(len(a)):
            if(a[i]==b[i]):
                num+=1
            else:
                print("预测失误:",a[i],"预测为",b[i])
        print("测试样本数为:",len(a))
        print("预测成功数为:",num)
        print("模型准确率为:",num/len(a))
  • 相关阅读:
    JAVA SSH 框架介绍
    Web开发者不可不知的15条编码原则
    全选,反选,全不选
    Python函数
    Python变量解析
    Python输入/输出语句
    Python程序基本架构
    Python开发环境安装
    java事件
    测试博客
  • 原文地址:https://www.cnblogs.com/dzhou/p/9480691.html
Copyright © 2011-2022 走看看