zoukankan      html  css  js  c++  java
  • 吴裕雄--天生自然python机器学习:基于支持向量机SVM的手写数字识别

     

    from numpy import *
    
    def img2vector(filename):
        """Read a 32x32 text image of digit characters into a 1x1024 row vector.

        The first 32 lines of ``filename`` are read; the first 32 characters of
        each line are parsed with ``int`` and pixel (i, j) is stored at column
        ``32*i + j`` of the result.

        :param filename: path of the text image file.
        :return: float array of shape (1, 1024) with the parsed digits.
        """
        returnVect = zeros((1, 1024))
        # ``with`` guarantees the handle is closed (the original leaked it).
        with open(filename) as fr:
            for i in range(32):
                lineStr = fr.readline()
                for j in range(32):
                    returnVect[0, 32 * i + j] = int(lineStr[j])
        return returnVect
    
    def loadImages(dirName, negDigit=9):
        """Load every image file in ``dirName`` as one row of a binary-labelled dataset.

        The class digit is taken from the file name: the part before the first
        '.' is split on '_' and its first field is parsed as an int. Files whose
        digit equals ``negDigit`` get label -1; all other digits get label +1.

        :param dirName: directory containing the 32x32 text images.
        :param negDigit: digit mapped to label -1 (default 9, the original behavior).
        :return: ``(trainingMat, hwLabels)`` -- an (m, 1024) array with one row
            per file, and a list of m labels in {-1, 1} in the same order.
        """
        from os import listdir
        from os.path import join
        # os.listdir order is arbitrary; sort so row order (and therefore any
        # training run built on it) is reproducible across platforms.
        trainingFileList = sorted(listdir(dirName))
        m = len(trainingFileList)
        trainingMat = zeros((m, 1024))
        hwLabels = []
        for i, fileNameStr in enumerate(trainingFileList):
            fileStr = fileNameStr.split('.')[0]     # take off .txt
            classNumStr = int(fileStr.split('_')[0])
            hwLabels.append(-1 if classNumStr == negDigit else 1)
            trainingMat[i, :] = img2vector(join(dirName, fileNameStr))
        return trainingMat, hwLabels
    
    def smoP(dataMatIn, classLabels, C, toler, maxIter, kTup=('lin', 0)):
        """Run the full Platt SMO outer loop and return the trained ``(b, alphas)``.

        Passes alternate between scanning every sample index and scanning only
        the non-bound indices (0 < alpha < C). For each candidate index
        ``innerL`` is called and its return values are summed as the count of
        changed alpha pairs. Training stops after ``maxIter`` passes, or when a
        full scan changes nothing. Progress is printed for every index and pass.

        :return: ``(oS.b, oS.alphas)`` from the ``optStruct`` built for training.
        """
        oS = optStruct(mat(dataMatIn), mat(classLabels).transpose(), C, toler, kTup)
        passNum = 0
        scanAll = True
        changed = 0
        while passNum < maxIter and (changed > 0 or scanAll):
            if scanAll:
                candidates = range(oS.m)
                template = "fullSet, iter: %d i:%d, pairs changed %d"
            else:
                alphaArr = oS.alphas.A
                candidates = nonzero((alphaArr > 0) * (alphaArr < C))[0]
                template = "non-bound, iter: %d i:%d, pairs changed %d"
            changed = 0
            for idx in candidates:
                changed += innerL(idx, oS)
                print(template % (passNum, idx, changed))
            passNum += 1
            # A full scan is always followed by non-bound scans; those repeat
            # until one changes nothing, which schedules another full scan.
            scanAll = (not scanAll) and changed == 0
            print("iteration number: %d" % passNum)
        return oS.b, oS.alphas
    
    def _errorRate(datMat, labelArr, sVs, weights, b, kTup):
        """Return the fraction of rows of ``datMat`` whose predicted sign differs from ``labelArr``.

        Each prediction is ``kernelTrans(sVs, row, kTup).T * weights + b``,
        where ``weights`` is the elementwise product of support-vector labels
        and their alphas.
        """
        m = shape(datMat)[0]
        errorCount = 0
        for i in range(m):
            kernelEval = kernelTrans(sVs, datMat[i, :], kTup)
            predict = kernelEval.T * weights + b
            if sign(predict) != sign(labelArr[i]):
                errorCount += 1
        return float(errorCount) / m

    def testDigits(kTup=('rbf', 10),
                   trainDir=r'F:\machinelearninginaction\Ch06\trainingDigits',
                   testDir=r'F:\machinelearninginaction\Ch06\testDigits'):
        """Train an SVM on the digits in ``trainDir`` and print training/test error rates.

        Training uses ``smoP`` with C=200, toler=0.0001 and maxIter=10000. Samples
        whose alpha is > 0 are kept as support vectors and used to classify both
        the training set and the images in ``testDir``.

        :param kTup: kernel spec forwarded unchanged to ``smoP`` and ``kernelTrans``.
        :param trainDir: directory of training images (read with ``loadImages``).
        :param testDir: directory of test images (read with ``loadImages``).
        """
        # NOTE: the defaults are raw strings on purpose. In the original plain
        # literals, "\t" in "\trainingDigits" / "\testDigits" became a TAB
        # character, so those directories could never be found.
        dataArr, labelArr = loadImages(trainDir)
        b, alphas = smoP(dataArr, labelArr, 200, 0.0001, 10000, kTup)
        datMat = mat(dataArr)
        labelMat = mat(labelArr).transpose()
        # Only samples with a positive multiplier take part in predictions.
        svInd = nonzero(alphas.A > 0)[0]
        sVs = datMat[svInd]
        labelSV = labelMat[svInd]
        print("there are %d Support Vectors" % shape(sVs)[0])
        # Loop-invariant: computed once instead of once per evaluated sample.
        weights = multiply(labelSV, alphas[svInd])
        trainErr = _errorRate(datMat, labelArr, sVs, weights, b, kTup)
        print("the training error rate is: %f" % trainErr)
        testArr, testLabels = loadImages(testDir)
        testErr = _errorRate(mat(testArr), testLabels, sVs, weights, b, kTup)
        print("the test error rate is: %f" % testErr)
    # Run the experiment only when executed as a script, not when imported
    # (training over the whole digit set is slow and prints heavily).
    if __name__ == "__main__":
        testDigits(('rbf', 20))

     

     

     

  • 相关阅读:
    lua编程之协程介绍
    lua编程之元表与元方法
    设计模式系列之单例模式
    设计模式系列之生成器模式
    设计模式系列之抽象工厂模式
    设计模式系列之原型模式
    设计模式系列之工厂模式
    stl源码分析之hash table
    2018/2019款 MacBookPro 接口失灵的原因及解决方案
    test
  • 原文地址:https://www.cnblogs.com/tszr/p/12047818.html
Copyright © 2011-2022 走看看