  • Implementing KNN in Python to recognize handwritten digits

    I wrote a KNN algorithm for recognizing handwritten digits; the code is shown below. Reference: http://blog.csdn.net/april_newnew/article/details/44176059

    # -*- coding: utf-8 -*-
    
    import numpy as np
    import pandas as pd
    import os
    def readtxt(filename):
        # Read one sample file; each line is converted to a single float.
        with open(filename,'r',encoding='utf-8') as f:
            txt = [float(line) for line in f]
        return txt
    
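    def readdata(rootfile):
        data = []
        label = []
        # Walk the directory tree and load every sample file.
        for root,dirs,files in os.walk(rootfile):
            for name in files:
                # os.path.join builds the path portably; the bare '\' was a syntax error.
                filename = os.path.join(root,name)
                txt = readtxt(filename)
                data.append(txt)
                # The class label is the part of the file name before the first '_'.
                label1 = name.split('_')[0]
                label.append(label1)
        data = pd.DataFrame(data)
        return data,label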
    
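    def KNN(traindata,trainlabel,testdatai,K):
        # Repeat the test sample so it lines up row by row with the training set.
        length = len(traindata)
        newtest = np.tile(testdatai, (length,1))
        newtest = pd.DataFrame(newtest)
        # Euclidean distance from the test sample to every training sample.
        diff = newtest - traindata
        diff = diff**2
        cha = diff.sum(axis=1)
        cha = cha**0.5
        result = pd.DataFrame({'label':trainlabel,
                               'cha':cha})
        # Keep the K nearest neighbours and count how often each label occurs.
        labels = result.sort_values(by='cha')[:K]
        frequent = labels.groupby('label').size()
        # idxmax returns the label itself; in current pandas, argmax returns a position.
        labely = frequent.idxmax()
        return labely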
            
    def test(trainfile,testfile,K):
        result = []
        traindata, trainlabel= readdata(trainfile)
        testdata, testlabel = readdata(testfile)
        # Classify each test sample in turn.
        for i in range(len(testdata)):
            labely = KNN(traindata,trainlabel,testdata.loc[i,:],K)
            result.append(labely)
        tongji  = pd.DataFrame({'result':result,'testlabel':testlabel})
        # Accuracy: fraction of predictions that match the true labels.
        accuracy = len(tongji[tongji['result']==tongji['testlabel']])/len(result)
        return result,accuracy
        
    trainfile=r'E:\trainingDigits'
    testfile=r'E:\testDigits'
    K=3
    result, accuracy= test(trainfile,testfile,K)
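
    As a quick sanity check of the voting logic, independent of the digit files, the sketch below applies the same Euclidean-distance, majority-vote idea to a made-up two-cluster dataset. The helper name `knn_predict` and the toy points are illustrative assumptions, not part of the code above.

```python
import numpy as np
from collections import Counter

def knn_predict(train_x, train_y, query, k):
    """Hypothetical helper: label `query` by majority vote of its k nearest neighbours."""
    train_x = np.asarray(train_x, dtype=float)
    query = np.asarray(query, dtype=float)
    # Euclidean distance from the query to every training row (via broadcasting).
    dists = np.sqrt(((train_x - query) ** 2).sum(axis=1))
    # Indices of the k smallest distances.
    nearest = np.argsort(dists, kind='stable')[:k]
    # Count the labels among those neighbours and return the most common one.
    votes = Counter(train_y[i] for i in nearest)
    return votes.most_common(1)[0][0]

# Toy data: two well-separated clusters labelled '0' and '1'.
toy_x = [[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]]
toy_y = ['0', '0', '0', '1', '1', '1']

print(knn_predict(toy_x, toy_y, [0.2, 0.1], k=3))
print(knn_predict(toy_x, toy_y, [5.5, 5.2], k=3))
```

    If a check like this passes but the accuracy on real files stays low, the data loading step is a more likely suspect than the voting step.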
                

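    Note: the training set has 2,210 records and the test set has 670. The accuracy is not high, only 0.45. I don't yet know why; I will keep studying and try to optimize the code.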
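
    One possible cause, offered as a guess rather than a diagnosis: if each sample file stores the image as rows of '0'/'1' characters (an assumption, since the post does not describe the file format), then `readtxt` turns each whole row into one very large float. A sample then has only as many features as it has rows, and the leftmost characters of each row carry far more weight in the distance than the rest. A minimal sketch of reading one feature per pixel instead, using a hypothetical helper `read_pixels` and a made-up 4x4 sample:

```python
import os
import tempfile

def read_pixels(filename):
    """Hypothetical alternative to readtxt: one float feature per character."""
    pixels = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            row = line.strip()
            # Each '0'/'1' character becomes its own feature.
            pixels.extend(float(ch) for ch in row)
    return pixels

# Made-up 4x4 sample written to a temporary file for demonstration.
sample = "0110\n1001\n1001\n0110\n"
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, '0_1.txt')
    with open(path, 'w', encoding='utf-8') as f:
        f.write(sample)
    features = read_pixels(path)

print(len(features))   # 16 features: one per pixel
print(features[:4])    # [0.0, 1.0, 1.0, 0.0]
```

    For a 32x32 image this would give 1,024 features per sample. Whether it actually raises the accuracy would need to be checked against the real data.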

  • Original post: https://www.cnblogs.com/chenyaling/p/7244266.html