  • The Backend Programmer's Road 13: Digit Recognition with KNN

    Some experiments with KNN (k-nearest neighbours) for digit recognition. The test data comes from:
    MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
    http://yann.lecun.com/exdb/mnist/

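    1. Data
    Each bitmap is flattened into a vector (an array of pixel values). k is tried with values from 3 to 15, and the distance between two vectors is the Euclidean distance, where n is the number of pixels (784 for a 28 × 28 MNIST image):
    $$d(x, y) = \sqrt{\sum_{i=1}^{n} (x_i - y_i)^2}$$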
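    Here is a minimal NumPy sketch of this distance for checking the formula on small vectors. The function name `euclidean_distance` is hypothetical; the original code computes the same quantity inline in `kNNClassify`. The sketch casts to float64 first, assuming pixel data may arrive as `uint8`. With `uint8`, subtraction wraps around instead of going negative, so 3 - 4 becomes 255.

```python
import numpy as np

def euclidean_distance(x, y):
    """d(x, y) = sqrt(sum_i (x_i - y_i)^2) for two equal-length vectors."""
    # float64 avoids the wrap-around that uint8 subtraction would cause
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    return float(np.sqrt(np.sum((x - y) ** 2)))

# Two 2-pixel "images": (0, 0) and (3, 4) are 5 apart (a 3-4-5 triangle).
print(euclidean_distance([0, 0], [3, 4]))  # 5.0
```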

    2. Testing
    Vary k and the number of training samples, then measure how k affects recognition accuracy and how long classification takes.
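    The experiment described here could be organised as a small harness that records accuracy and wall-clock time for each training-set size and each k. This is a sketch under several assumptions:
    - `knn_predict` and `sweep` are hypothetical names.
    - Features are a float array with one flattened image per row.
    - Labels are a NumPy integer array. A plain Python list of labels would need `numpy.asarray` first.
    - Training subsets are simply the first n samples.

```python
import time
from collections import Counter

import numpy as np

def knn_predict(x, train_x, train_y, k):
    """Label of x by majority vote among its k nearest training rows (Euclidean)."""
    dist = np.sqrt(np.sum((train_x - x) ** 2, axis=1))
    nearest = np.argsort(dist)[:k]          # indices of the k closest rows
    return Counter(train_y[nearest].tolist()).most_common(1)[0][0]

def sweep(train_x, train_y, test_x, test_y, ks, train_sizes):
    """Return (train_size, k, accuracy, seconds) for every combination."""
    results = []
    for n in train_sizes:
        tx, ty = train_x[:n], train_y[:n]   # first n training samples
        for k in ks:
            start = time.perf_counter()
            correct = sum(knn_predict(x, tx, ty, k) == y
                          for x, y in zip(test_x, test_y))
            elapsed = time.perf_counter() - start
            results.append((n, k, correct / len(test_y), elapsed))
    return results
```

    For example, `sweep(train_x, numpy.asarray(train_y), test_x, numpy.asarray(test_y), ks=range(3, 16), train_sizes=[1000, 10000, 60000])` covers the k range of 3 to 15 mentioned above. On the full test set this is slow, so a first run could slice both `test_x` and `test_y` (e.g. `[:500]`).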

    Reference on parsing the MNIST files in Python: "How to parse MNIST images with Python", 海上扬凡's blog, CSDN.NET
    http://blog.csdn.net/u014046170/article/details/47445919

    # -*- coding: utf-8 -*-
    """
    Created on Wed Mar 08 14:38:15 2017

    @author: zapline<278998871@qq.com>
    """

    # Saved as dataset.py; the test script imports it as `dataset`.
    import struct
    import os
    import numpy

    def read_file_data(filename):
        # read the whole file into memory as bytes
        with open(filename, 'rb') as f:
            return f.read()

    def loadImageDataSet(filename):
        # IDX image file: big-endian header (magic, image count, rows, columns),
        # followed by one unsigned byte per pixel, row by row
        index = 0
        buf = read_file_data(filename)
        magic, images, rows, columns = struct.unpack_from('>IIII', buf, index)
        index += struct.calcsize('>IIII')
        data = numpy.zeros((images, rows * columns))  # one flattened image per row
        for i in range(images):
            for p in range(rows * columns):
                data[i, p] = struct.unpack_from('>B', buf, index)[0]
                index += struct.calcsize('>B')
        return data

    def loadLabelDataSet(filename):
        # IDX label file: big-endian header (magic, label count),
        # followed by one unsigned byte (0-9) per label
        index = 0
        buf = read_file_data(filename)
        magic, items = struct.unpack_from('>II', buf, index)
        index += struct.calcsize('>II')
        data = []
        for i in range(items):
            label = struct.unpack_from('>B', buf, index)[0]
            index += struct.calcsize('>B')
            data.append(label)
        return data

    def loadDataSet():
        # raw string so the backslashes are kept literally
        path = r"D:\kingsoft\ml\dataset"
        trainingImageFile = os.path.join(path, "train-images.idx3-ubyte")
        trainingLabelFile = os.path.join(path, "train-labels.idx1-ubyte")
        testingImageFile = os.path.join(path, "t10k-images.idx3-ubyte")
        testingLabelFile = os.path.join(path, "t10k-labels.idx1-ubyte")
        train_x = loadImageDataSet(trainingImageFile)
        train_y = loadLabelDataSet(trainingLabelFile)
        test_x = loadImageDataSet(testingImageFile)
        test_y = loadLabelDataSet(testingLabelFile)
        return train_x, train_y, test_x, test_y
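    The loaders above read one byte per `struct.unpack_from` call. For the 60,000 training images, that is 60000 × 28 × 28 = 47,040,000 calls in pure Python.

    A faster alternative reads the whole payload with a single `numpy.frombuffer` call. This sketch assumes the standard IDX layout used above: big-endian 32-bit header fields, followed by one unsigned byte per value. It also checks the magic numbers listed on the MNIST page: 2051 for image files and 2049 for label files. `parse_idx_images` and `parse_idx_labels` are hypothetical names. They take raw bytes, so they can be tried without the dataset files.

```python
import struct

import numpy as np

IMAGE_MAGIC = 2051  # 0x00000803: unsigned-byte data, 3 dimensions
LABEL_MAGIC = 2049  # 0x00000801: unsigned-byte data, 1 dimension

def parse_idx_images(buf):
    """Return an (images, rows*columns) float64 array from IDX image bytes."""
    magic, images, rows, columns = struct.unpack_from('>IIII', buf, 0)
    if magic != IMAGE_MAGIC:
        raise ValueError('not an IDX image file (magic %d)' % magic)
    # pixels start right after the 16-byte header
    pixels = np.frombuffer(buf, dtype=np.uint8,
                           count=images * rows * columns, offset=16)
    # float64 so that later subtractions cannot wrap around like uint8 does
    return pixels.reshape(images, rows * columns).astype(np.float64)

def parse_idx_labels(buf):
    """Return a 1-D int64 array of labels from IDX label bytes."""
    magic, items = struct.unpack_from('>II', buf, 0)
    if magic != LABEL_MAGIC:
        raise ValueError('not an IDX label file (magic %d)' % magic)
    # labels start right after the 8-byte header
    return np.frombuffer(buf, dtype=np.uint8, count=items, offset=8).astype(np.int64)
```

    With the original helper, `parse_idx_images(read_file_data(path))` would stand in for `loadImageDataSet(path)`. The label version returns a NumPy array instead of a list, but indexing it in `kNNClassify` works the same way.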


    # -*- coding: utf-8 -*-
    """
    Created on Wed Mar 08 14:35:55 2017

    @author: zapline<278998871@qq.com>
    """

    # Saved as knn.py; the test script imports it as `knn`.
    import numpy

    def kNNClassify(newInput, dataSet, labels, k):
        # Euclidean distance from newInput to every training sample;
        # broadcasting subtracts newInput from each row of dataSet
        diff = dataSet - newInput
        distance = numpy.sqrt(numpy.sum(diff ** 2, axis=1))
        # indices of the training samples, nearest first
        sortedDistIndices = numpy.argsort(distance)

        # count the votes of the k nearest neighbours
        classCount = {}
        for i in range(k):
            voteLabel = labels[sortedDistIndices[i]]
            classCount[voteLabel] = classCount.get(voteLabel, 0) + 1

        # return the label with the most votes
        maxCount = 0
        maxIndex = None
        for key, value in classCount.items():
            if value > maxCount:
                maxCount = value
                maxIndex = key
        return maxIndex
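    `kNNClassify` can be tightened in two ways without changing which samples count as nearest:
    - The square root can be skipped, because it preserves ordering.
    - A full `argsort` is unnecessary, because only the k smallest distances matter.

    Below is a sketch of that variant using `numpy.argpartition` and `collections.Counter`. `knn_classify_fast` is a hypothetical name. It may pick a different winner than the original loop in two cases: when two labels receive the same number of votes, or when several samples tie at the k-th distance.

```python
from collections import Counter

import numpy as np

def knn_classify_fast(new_input, data_set, labels, k):
    """Majority label among the k training rows closest to new_input (requires 1 <= k <= rows)."""
    # squared Euclidean distance gives the same neighbour ordering as the true distance
    sq_dist = np.sum((data_set - new_input) ** 2, axis=1)
    # indices of the k smallest distances, in no particular order (no full sort)
    nearest = np.argpartition(sq_dist, k - 1)[:k]
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]
```

    The `labels` argument may be a list or a NumPy array, so the function accepts the label list built by the loader as-is.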


    # -*- coding: utf-8 -*-
    """
    Created on Wed Mar 08 14:39:21 2017

    @author: zapline<278998871@qq.com>
    """

    import dataset
    import knn

    def testHandWritingClass():
        print("step 1: load data...")
        train_x, train_y, test_x, test_y = dataset.loadDataSet()

        print("step 2: training...")
        # KNN has no training phase: the training set itself is the model
        pass

        print("step 3: testing...")
        numTestSamples = test_x.shape[0]
        matchCount = 0
        for i in range(numTestSamples):
            # classify each test image against the full training set with k = 3
            predict = knn.kNNClassify(test_x[i], train_x, train_y, 3)
            if predict == test_y[i]:
                matchCount += 1
        accuracy = float(matchCount) / numTestSamples

        print("step 4: show the result...")
        print('The classify accuracy is: %.2f%%' % (accuracy * 100))

    if __name__ == '__main__':
        testHandWritingClass()
        print("game over")

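    Summary: the code above runs rather slowly, because every test image is compared against every training image at classification time. With enough training data, however, the accuracy is good.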
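    One common way to cut the classification cost is to process test images in batches. All squared distances for a batch can then be computed at once through the identity $\|a-b\|^2 = \|a\|^2 - 2\,a \cdot b + \|b\|^2$, which turns the inner loop into a matrix multiplication.

    Below is a sketch with the hypothetical names `batch_knn_predict` and `batch_size`. It assumes float64 features and a NumPy array of non-negative integer labels. `batch_size` trades memory for speed: a batch of 200 test rows against 60,000 training rows produces a 200 × 60,000 float64 matrix, about 96 MB. The expanded form can differ from the direct difference by small rounding errors, which may occasionally reorder near-equal distances.

```python
import numpy as np

def batch_knn_predict(test_x, train_x, train_y, k, batch_size=200):
    """Predict labels for all rows of test_x with k-NN (squared Euclidean distance)."""
    train_sq = np.sum(train_x ** 2, axis=1)             # ||b||^2 for every training row
    preds = np.empty(len(test_x), dtype=train_y.dtype)
    for start in range(0, len(test_x), batch_size):
        batch = test_x[start:start + batch_size]
        batch_sq = np.sum(batch ** 2, axis=1)[:, None]   # ||a||^2 as a column
        # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 for every (test, train) pair
        sq_dist = batch_sq - 2.0 * (batch @ train_x.T) + train_sq
        # k nearest training indices per test row, unordered
        nearest = np.argpartition(sq_dist, k - 1, axis=1)[:, :k]
        for row, idx in enumerate(nearest):
            # majority vote; argmax over bincount breaks ties toward the smaller label
            preds[start + row] = np.argmax(np.bincount(train_y[idx]))
    return preds
```

    For example: `accuracy = numpy.mean(batch_knn_predict(test_x, train_x, numpy.asarray(train_y), 3) == numpy.asarray(test_y))`.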

  • Original post: https://www.cnblogs.com/zapline/p/6546574.html