zoukankan      html  css  js  c++  java
  • 支持向量机SMO算法总是死循环

    在学习李航老师的《统计学习方法》中,我用Python编写了SVM的SMO算法。但在编写完成后发现目标函数总是停留在某一个值不再下降,仔细调试发现,每次选择变量时总是选择同一变量,虽然按照计算可以找到使目标函数下降的新值,但是被范围限制约束后又变回之前的值。所以算法总是在这个值上循环。我修改算法,迫使每次选择的必须是不同的变量。但即使如此还是会出现死循环,只是不再是一组变量上,而是多组周期循环。仔细研究了算法,发现书上有一句话很关键

    实践证明这并不是什么特殊情况,而是一般情况。SMO如果变量选择不做特殊处理,很容易进入死循环。我的修改方法很简单,首先把备选的alpha1和alpha2变量排序,再定义一个变量k,从0开始计数。同时在类的成员变量中也记录每次选择的变量编号和相对应的目标函数值。首先查找目标函数值尾部相同的值,寻找到对应的变量编号。那么目标就是不能与这些变量编号重复。根据算法要求首先应该在alpha2上做内层遍历,再在alpha1中做外层遍历。每次查找alpha1用k整除样本数,查找alpha2用k对样本数取余,如果结果不能满足k加一。这样就能满足算法要求。

    这里附上变量选择部分的代码

     1     def getVariableIndex(self):
     2         """选择两个变量"""
     3         isOK=False
     4         k=0
     5         while not isOK:
     6 
     7             row=self.x.shape[0]
     8             kkt=np.zeros(row)
     9             for i in range(row):
    10                 kkt[i]=self.KKT(i)
    11             kktSorted=np.sort(kkt)
    12             index1=np.where(kkt==kktSorted[row-k//row-1])[0][0]
    13 
    14             E=np.sort(self.E)
    15             e1=self.E[index1]
    16             if e1>0:
    17                 index2=np.where(E[k%row]==self.E)[0][0]
    18                 if index2==index1:
    19                     index2=np.where(E[(k+1)%row]==self.E)[0][0]
    20             else:
    21                 index2=np.where(E[row-1-k%row]==self.E)[0][0]
    22                 if index2==index1:
    23                     index2=np.where(E[row-1-(k+1)%row]==self.E)[0][0]
    24             isOK=True
    25             num=len(self.aim)
    26             if num>0:
    27                 aim=self.aim[num-1]
    28                 for i in range(num-1,-1,-1):
    29                     if(aim==self.aim[i]):
    30                         if self.variable[i][0]==index1 and self.variable[i][1]==index2:
    31                             isOK=False
    32                             break
    33                     else:
    34                         break
    35             k += 1
    36 
    37         return index1,index2

    SVM全部代码可以看我的Github:https://github.com/sgdd66/MachineLearning

  • 相关阅读:
    cf B. Sereja and Suffixes
    cf E. Dima and Magic Guitar
    cf D. Dima and Trap Graph
    cf C. Dima and Salad
    最短路径问题(floyd)
    Drainage Ditches(网络流(EK算法))
    图结构练习—BFSDFS—判断可达性(BFS)
    Sorting It All Out(拓扑排序)
    Power Network(最大流(EK算法))
    Labeling Balls(拓扑)
  • 原文地址:https://www.cnblogs.com/sgdd123/p/7824834.html
Copyright © 2011-2022 走看看