  • Support Vector Machines: Handwritten Digit Recognition

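    KNN also classifies handwritten digits well, but it has to keep every training sample in order to make a prediction. A support vector machine only needs to keep the support vectors that define the decision boundary.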

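    Some of the functions used here are the same as in the previous post (支持向量机-在复杂数据上应用核函数, "Support Vector Machines: Applying Kernel Functions to Complex Data"), so the simplest approach is to append the code in this section to the code from that post. In particular, `smoP` and `kernelTrans` are not defined below, and the NumPy names used unqualified (`zeros`, `mat`, `nonzero`, `shape`, `multiply`, `sign`) are assumed to be imported there.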

    1. Convert an image to a vector

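    def img2vector(filename):
        # flatten a 32x32 text image of '0'/'1' characters into a 1x1024 row vector
        returnVect = zeros((1, 1024))
        with open(filename) as fr:  # the file is closed when the block ends
            for i in range(32):
                lineStr = fr.readline()
                for j in range(32):
                    returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect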
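
    To check the flattening order, the function can be run on a small synthetic image. The following is a minimal sketch, not part of the original code: it uses `import numpy as np` instead of the star import so that it runs on its own, and the image (a two-pixel-wide vertical bar) is made up for the example.

```python
import os
import tempfile

import numpy as np


def img2vector(filename):
    """Flatten a 32x32 text image of '0'/'1' characters into a 1x1024 row vector."""
    return_vect = np.zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            line_str = fr.readline()
            for j in range(32):
                # row i, column j of the image lands at index 32 * i + j
                return_vect[0, 32 * i + j] = int(line_str[j])
    return return_vect


# Synthetic image (made up for this example): columns 15 and 16 are set in every row.
rows = ["0" * 15 + "11" + "0" * 15 for _ in range(32)]

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("\n".join(rows) + "\n")
    tmp_path = tmp.name

vec = img2vector(tmp_path)
os.remove(tmp_path)

print(vec.shape)       # (1, 1024)
print(int(vec.sum()))  # 64, i.e. 2 pixels in each of the 32 rows
```

    Row i, column j of the image ends up at index 32 * i + j, so the bar shows up as ones at indices 15, 16, 47, 48, and so on.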

    2. Load the images

    def loadImages(dirName):
        from os import listdir
        hwLabels = []
        print(dirName)
        trainingFileList = listdir(dirName)  # load the training set
        m = len(trainingFileList)
        trainingMat = zeros((m, 1024))
        for i in range(m):
            fileNameStr = trainingFileList[i]
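            fileStr = fileNameStr.split('.')[0]  # take off .txt
            # the digit is the part of the file name before the underscore
            classNumStr = int(fileStr.split('_')[0])
            # binary labels: 9 -> -1, any other digit -> +1
            if classNumStr == 9:
                hwLabels.append(-1)
            else:
                hwLabels.append(1)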
            trainingMat[i, :] = img2vector('%s/%s' % (dirName, fileNameStr))
        return trainingMat, hwLabels
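
    The label is derived purely from the file name. A minimal sketch of that step in isolation follows; the helper name `label_from_filename` and the sample file names are hypothetical, chosen to match the `<digit>_<index>.txt` pattern that `loadImages` parses.

```python
def label_from_filename(file_name):
    """Map a '<digit>_<index>.txt' file name to a binary SVM label."""
    file_stem = file_name.split('.')[0]      # drop the .txt extension
    digit = int(file_stem.split('_')[0])     # the part before the underscore
    return -1 if digit == 9 else 1           # 9 -> -1, anything else -> +1


# Hypothetical file names, for illustration only.
names = ['1_0.txt', '9_12.txt', '1_57.txt']
labels = [label_from_filename(name) for name in names]
print(labels)  # [1, -1, 1]
```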

    3. Test digit classification

    def testDigits(kTup=('rbf', 10)):
    
        # 1. Load the training data
        dataArr, labelArr = loadImages('F:/迅雷下载/machinelearninginaction/Ch06/trainingDigits/')
        b, alphas = smoP(dataArr, labelArr, 200, 0.0001, 10000, kTup)
        datMat = mat(dataArr)
        labelMat = mat(labelArr).transpose()
        svInd = nonzero(alphas.A > 0)[0]
        sVs = datMat[svInd]
        labelSV = labelMat[svInd]
        print("there are %d Support Vectors" % shape(sVs)[0])
        m, n = shape(datMat)
        errorCount = 0
        for i in range(m):
            kernelEval = kernelTrans(sVs, datMat[i, :], kTup)
            # (1 x nSV) * (nSV x 1) = 1 x 1, a single prediction (nSV = number of support vectors)
            predict = kernelEval.T * multiply(labelSV, alphas[svInd]) + b
            if sign(predict) != sign(labelArr[i]): errorCount += 1
        print("the training error rate is: %f" % (float(errorCount) / m))
    
        # 2. Load the test data
        dataArr, labelArr = loadImages('F:/迅雷下载/machinelearninginaction/Ch06/testDigits/')
        errorCount = 0
        datMat = mat(dataArr)
        labelMat = mat(labelArr).transpose()
        m, n = shape(datMat)
        for i in range(m):
            kernelEval = kernelTrans(sVs, datMat[i, :], kTup)
            predict = kernelEval.T * multiply(labelSV, alphas[svInd]) + b
            if sign(predict) != sign(labelArr[i]): errorCount += 1
        print("the test error rate is: %f" % (float(errorCount) / m))

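    A support vector machine is a binary classifier, so every prediction is either +1 or -1. Here the digit 1 is labeled +1 and the digit 9 is labeled -1; all other digits have been removed from the data set.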
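
    The prediction line in `testDigits` evaluates f(x) = Σ αᵢ yᵢ K(svᵢ, x) + b over the support vectors only, which is why the rest of the training set can be discarded. The sketch below illustrates this with plain NumPy arrays. Since `kernelTrans` is not shown in this post, the kernel form K(x, a) = exp(-‖x - a‖² / σ²) is an assumption, and the function names, support vectors, alphas and b are toy values rather than the output of `smoP`.

```python
import numpy as np


def rbf_kernel(X, a, sigma):
    """Assumed RBF form: exp(-||x_i - a||^2 / sigma^2) for every row x_i of X."""
    diff = X - a
    return np.exp(-np.sum(diff * diff, axis=1) / sigma ** 2)


def svm_predict(x, support_vectors, sv_labels, sv_alphas, b, sigma):
    """Sign of sum_i alpha_i * y_i * K(sv_i, x) + b, using support vectors only."""
    k = rbf_kernel(support_vectors, x, sigma)
    return np.sign(np.dot(k, sv_labels * sv_alphas) + b)


# Toy model (made-up values, not a trained classifier).
support_vectors = np.array([[0.0, 0.0], [2.0, 2.0]])
sv_labels = np.array([1.0, -1.0])
sv_alphas = np.array([0.5, 0.5])
b = 0.0

near_positive = svm_predict(np.array([0.1, 0.0]), support_vectors, sv_labels, sv_alphas, b, sigma=1.0)
near_negative = svm_predict(np.array([1.9, 2.0]), support_vectors, sv_labels, sv_alphas, b, sigma=1.0)
print(near_positive, near_negative)  # 1.0 -1.0
```

    Each point takes the label of the support vector it is close to, because the kernel value for the distant support vector is nearly zero.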

    testDigits(('rbf', 20))
    F:/迅雷下载/machinelearninginaction/Ch06/trainingDigits/
    iteration number: 1
    iteration number: 2
    iteration number: 3
    iteration number: 4
    iteration number: 5
    iteration number: 6
    iteration number: 7
    iteration number: 8
    there are 67 Support Vectors
    the training error rate is: 0.000000
    F:/迅雷下载/machinelearninginaction/Ch06/testDigits/
    the test error rate is: 0.010753
    testDigits(('rbf', 50))
    F:/迅雷下载/machinelearninginaction/Ch06/trainingDigits/
    iteration number: 1
    iteration number: 2
    iteration number: 3
    iteration number: 4
    iteration number: 5
    iteration number: 6
    there are 34 Support Vectors
    the training error rate is: 0.014925
    F:/迅雷下载/machinelearninginaction/Ch06/testDigits/
    the test error rate is: 0.010753
    testDigits(('rbf', 10))
    F:/迅雷下载/machinelearninginaction/Ch06/trainingDigits/
    iteration number: 1
    iteration number: 2
    iteration number: 3
    iteration number: 4
    iteration number: 5
    iteration number: 6
    there are 118 Support Vectors
    the training error rate is: 0.000000
    F:/迅雷下载/machinelearninginaction/Ch06/testDigits/
    the test error rate is: 0.010753

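    It is worth experimenting with σ. A value of around 10 for the radial basis function kernel should give the lowest test error, although in the three runs above the test error is identical (0.010753) for σ = 10, 20 and 50. What changes is the number of support vectors (118, 67 and 34 respectively) and, at σ = 50, a nonzero training error (0.014925).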
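
    One plausible reading of the support vector counts is that a smaller σ makes the kernel fall off faster, so each support vector influences a smaller neighbourhood and more of them are needed. The sketch below shows the fall-off for the three σ values used above, again assuming the kernel form exp(-‖x - a‖² / σ²); the squared distance of 100 is a made-up example (for binary images it corresponds to 100 differing pixels).

```python
import numpy as np


def rbf_value(squared_distance, sigma):
    """Assumed RBF form: exp(-d^2 / sigma^2)."""
    return np.exp(-squared_distance / sigma ** 2)


squared_distance = 100.0  # hypothetical: two binary images that differ in 100 pixels
values = {sigma: rbf_value(squared_distance, sigma) for sigma in (10, 20, 50)}

for sigma, k in values.items():
    print("sigma=%d  K=%.4f" % (sigma, k))
```

    The kernel value is roughly 0.37 at σ = 10, 0.78 at σ = 20 and 0.96 at σ = 50, so with a large σ even fairly different images look similar to the classifier.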
  • Original post: https://www.cnblogs.com/gezhuangzhuang/p/9971230.html