zoukankan      html  css  js  c++  java
  • KNN--用于手写数字识别

    优点:精度高,对异常值不敏感,无数据输入假定
    缺点:计算复杂度高,空间复杂度高
    适用数据范围:数值型和标称型
     
    一般流程:
        (1). 收集数据(网络抓取)
        (2).处理数据,将数据处理成结构化的数据格式。
        (3).分析数据
        (4).测试算法(主要是计算模型的出错率)
        (5).使用算法,
     
    K-近邻算法采用测量不同特征值之间的距离的方法进行分类
     
    工作原理是:存在一个训练样本集,且样本集中每个数据都存在标签(与分类的对应关系)。
    当输入没有标签的新数据后,将新数据的每个特征与样本集中的数据对应的特征进行比较,
    然后算法提取训练样本集中前k个最相似的数据的分类标签,且k不大于20 。
    选择最相似数据中出现次数最多的分类,作为新数据的分类。
     
     
    数据归一化的作用,只用在特征数据相差较大且同等重要的条件下
     

     
    上面方程中数字差值最大的属性对计算结果的影响最大,仅仅是因为飞行常客里程数远大于其他特征值。然而我们认为这三种特征同样重要,因此作为三个等权重的特征 
     
    直接上代码:
    from numpy import *
    import matplotlib.pyplot as plot
    import operator
    from os import listdir


    def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    # 距离计算公式
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5

    # 距离从大到小排序,返回距离的序号
    sortedDistIndicies = distances.argsort()
    # 声明一个空的字典,用于存放标签
    classCount={}
    for i in range(k):
    # sortedDistIndicies[0]返回的是距离最小的数据样本的序号
    # labels[sortedDistIndicies[0]]距离最小的数据样本的标签
    voteIlabel = labels[sortedDistIndicies[i]]
    classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
    # 给该字典排序,sortedClassCount[0][0]是K中支持的标签数最大的
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    print(sortedClassCount[0][0])
    return sortedClassCount[0][0]


    # 创建数据
    def createDataSet():
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = ['A','A','B','B']
    return group, labels

    # 画图
    def draw(xs,ys):
    fig = plot.figure()
    # 将画布分割成1行1列,图像画在从左到右从上到下的第1块
    # 设置画布的大小与图像的位置
    ax = fig.add_subplot(221)
    # ax.scatter(xs, ys)的两个参数分别是所有点的x坐标,所有点的y坐标
    ax.scatter(xs,ys)
    plot.show()

    def firstTest():
    test1 = (1.0, 1.2)
    test2 = (0.0, 0.4)
    dataset, labels = createDataSet()
    conclusion1 = classify0(test1, dataset, labels, 3)
    conclusion2 = classify0(test2, dataset, labels, 3)
    print(str(test1) + "分类后的结果是属于" + conclusion1 + "类")
    print(str(test2) + "分类后的结果是属于" + conclusion2 + "类")
    # 将32*32的矩阵读为1*1024
    def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
    lineStr = fr.readline()
    for j in range(32):
    returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

    def handwritingClassTest():
    hwLabels = []
    # 获得训练样本数据集
    trainingFileList = listdir('digits/trainingDigits')
    # 样本数的个数
    m = len(trainingFileList)
    # 返回m行1024列的矩阵数据
    trainingMat = zeros((m, 1024))
    # 文件名下划线_左边的数字是标签
    for i in range(m):
    fileNameStr = trainingFileList[i]
    fileStr = fileNameStr.split(".")[0]
    # 分类标签
    classNumStr = int(fileStr.split('_')[0])
    hwLabels.append(classNumStr)
    trainingMat[i, :] = img2vector('digits/trainingDigits/%s' % fileNameStr)
    testFileList = listdir('digits/testDigits')
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
    fileNameStr = testFileList[i]
    fileStr = fileNameStr.split('.')[0] # take off .txt
    classNumStr = int(fileStr.split('_')[0])
    vectorUnderTest = img2vector('digits/testDigits/%s' % fileNameStr)
    classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
    print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
    if (classifierResult != classNumStr): errorCount += 1.0
    print(" the total number of errors is: %d" % errorCount)
    print(" the total error rate is: %f" % (errorCount / float(mTest)))

    # 主函数调用模块函数
    if __name__ == "__main__":
    # group,label = createDataSet()
    # # group[:, 0] 所有行的第0列
    # draw(group[:, 0], group[:, 1])
    # # print(group[:, 0])
    # firstTest()
    handwritingClassTest()

    训练数据集合测试集的数据:https://gitee.com/lcl1993213/plist
  • 相关阅读:
    阿里云服务器无法通过浏览器访问
    浅谈java枚举类
    WebService基础学习
    cxf报错 Cannot find any registered HttpDestinationFactory from the Bus.
    Mybatis JdbcType与Oracle、MySql 数据类型对应关系
    plsql + instantclient 连接oracle ( 超简单)
    Shiro框架
    Java中 实体类 VO、 PO、DO、DTO、 BO、 QO、DAO、POJO的概念
    POI 生成 word 文档 简单版(包括文字、表格、图片、字体样式设置等)
    web.xml 配置文件 超详细说明!!!
  • 原文地址:https://www.cnblogs.com/lcl15/p/7987115.html
Copyright © 2011-2022 走看看