zoukankan      html  css  js  c++  java
  • 数学之路(3)-机器学习(3)-机器学习算法-SVM[7]

    本博客所有内容是原创,未经书面许可,严禁任何形式的转载

    http://blog.csdn.net/u010255642

    根据SMO的算法描述,用python实现,部分代码如下,定义了一个svm_pmcp类,所有的运算在svm_pmcp完成,这样便于封装和实际应用

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    #麦好:myhaspl@qq.com
    #http://blog.csdn.net/u010255642
    #svm算法
    import numpy as np import math
    import matplotlib.pyplot as plt
    
    #内积线性核函数
    def arraydot(x,y):
        return x.T*y
    
    #svm参数与计算类
    class svm_pmcp:
        def __init__(self):
            '''初始化参数变量'''
            self.alpha = []
            self.samples=[]
            self.labels=[]
            self.boundalpha=[]
        def samples_init(self,samples):
            '''样子及乘子参数初始化'''
            for (mysp,mylb) in samples:
                self.samples.append(mysp)
                self.labels.append(mylb)
            #初始化拉格朗日乘子alpha为0
            for i in xrange(0,len(self.samples)):
                self.alpha.append(0)
            #初始化b为0
            self.b = 0
        def kernel_init(self,func):
            '''指定核函数'''
            self.kernel_func=func
        def lagrange_multiplier(self,i):
            '''求拉格朗日乘子'''
            pass
        def svmoutput(self,i):
            pass
        def tol_init(self,mytol):
            self.tol=mytol
        def eps_inbit(self,myeps):
            self.eps=myeps
        def c_init(selfm,myc):
            self.c=myc
        def choicesecond_max(self,nte):
            pass
        def choicesecond_random(self):
            pass
        def get_lh(self,i,j):
            pass
        def update_b(self):
            pass
        def update_w(self):
            pass
        def alpha_nozero_noc(self):
            pass
        def store_alpha(self,i1,a1,i2,a2):
            pass
    
        def takestep(i1,i2,e2,alpha2):
            if (i1==i2):
                return False
            alpha1=lagrange_multiplier(i1)
            y1=labels[i1]
            e1=svmoutput(i1)-y1
            s=y1*y2
            l,h=get_lh(i2,i1)
            if l==h:
                return False
            k11=kernel_func(self.samples[i1],self.samples[i1])
            k12=kernel_func(self.samples[i1],self.samples[i2])
            k13=kernel_func(self.samples[i2],self.samples[i2])
            eta=float(2*k12-k11-k22)
            if (eta<0):
                a2=alpha2-y2*(e1-e2))/eta
                if a2<l:
                    a2=l
                elif a2>h:
                    a2=h
            else:
                lobj=obfuncl()
                hobj=obfunch()
                if lobj>hobj+self.eps:
                    a2=l
                elif lobj<hobj-self.eps:
                    a2=h
                else:
                    a2=alpha2
            if abs(a2-alpha2)<self.eps*(a2+alph2+self.eps):
                return False
            a1=alpha1+s*(alpha2-a2)
            update_b()
            update_w()
            store_alpha(i1,a1,i2,a2)
            return True
            
                    
                
                
                
    
    
    
        def examineexample(myi):
            y2=labels[myi]
            alpha2=lagrange_multiplier(myi)
            e2=svmoutput(myi)-y2
            r2=e2*y2
            if  ((r2<-self.tol and alpha2<self.c) or (r2>self.tol and alpha2>0):
                 if (len(self.boundalpha)>0):
                     myj=choicesecond_max(e)
                     if takestep(myj,myi,e2,alpha2):
                         return 1
                 else:
                     myj=choicesecond_random(myi)
                     if takestep(myj,myi,e2,alpha2):
                         return 1
            return 0
        
        def loop1(self,nc):
            for i in xrange(0,len(mysvm.samples)):
                nc+=examineexample(i)
        def loop2(self,nc):
            for i in alpha_nozero_noc():
                nc+=examineexample(i)
    
        def mainroutine(self):
            numchanged=0
            examineall=True
            while (numchanged>0 or examineall):
                numchanged=0
                if examineall:
                    numchanged=loop1(numchanged)
                else:
                    numchanged=loop2(numchanged)
                examineall=not examineall
    
    
    
    
    
    
    
    
    
    
    def mainsvm(mysamples):
        mysvm = svm_pmcp()
        mysvm.samples_init(mysamples)
        mysvm.kernel_init(arraydot)
        mysvm.tol_init(0.001)
        mysvm.eps_init(0.00001)
        mysvm.c_init(1)
        mysvm.mainroutine()
    
    
    
    
    
    
    
    
    
    

    后面关于svm的章节将提供类下载地址及调用代码

  • 相关阅读:
    huffman编码
    查询选修了全部课程的学生姓名【转】
    java中的内部类
    jdbc技术
    多态
    存储过程,触发器,Mysql权限,备份还原
    sql查询重点
    常用的sql基础
    手动+工具开发动态资源
    Tomcat
  • 原文地址:https://www.cnblogs.com/snake-hand/p/3206483.html
Copyright © 2011-2022 走看看