zoukankan      html  css  js  c++  java
  • SVM算法Python实现

      1 class SVM:
      2     def __init__(self, max_iter=100, kernel='linear'):
      3         self.max_iter = max_iter
      4         self._kernel = kernel
      5 
      6     def init_args(self, features, labels):
      7         self.m, self.n = features.shape
      8         self.X = features
      9         self.Y = labels
     10         self.b = 0.0
     11 
     12         # 将Ei保存在一个列表里
     13         self.alpha = np.ones(self.m)
     14         self.E = [self._E(i) for i in range(self.m)]
     15         # 松弛变量
     16         self.C = 1.0
     17 
     18     def _KKT(self, i):
     19         y_g = self._g(i) * self.Y[i]
     20         if self.alpha[i] == 0:
     21             return y_g >= 1
     22         elif 0 < self.alpha[i] < self.C:
     23             return y_g == 1
     24         else:
     25             return y_g <= 1
     26 
     27     # g(x)预测值,输入xi(X[i])
     28     def _g(self, i):
     29         r = self.b
     30         for j in range(self.m):
     31             r += self.alpha[j] * self.Y[j] * self.kernel(self.X[i], self.X[j])
     32         return r
     33 
     34     # 核函数
     35     def kernel(self, x1, x2):
     36         if self._kernel == 'linear':
     37             return sum([x1[k] * x2[k] for k in range(self.n)])
     38         elif self._kernel == 'poly':
     39             return (sum([x1[k] * x2[k] for k in range(self.n)]) + 1)**2
     40 
     41         return 0
     42 
     43     # E(x)为g(x)对输入x的预测值和y的差
     44     def _E(self, i):
     45         return self._g(i) - self.Y[i]
     46 
     47     def _init_alpha(self):
     48         # 外层循环首先遍历所有满足0<a<C的样本点,检验是否满足KKT
     49         index_list = [i for i in range(self.m) if 0 < self.alpha[i] < self.C]
     50         # 否则遍历整个训练集
     51         non_satisfy_list = [i for i in range(self.m) if i not in index_list]
     52         index_list.extend(non_satisfy_list)
     53 
     54         for i in index_list:
     55             if self._KKT(i):
     56                 continue
     57 
     58             E1 = self.E[i]
     59             # 如果E1是+,选择最小的;如果E1是负的,选择最大的
     60             if E1 >= 0:
     61                 j = min(range(self.m), key=lambda x: self.E[x])
     62             else:
     63                 j = max(range(self.m), key=lambda x: self.E[x])
     64             return i, j
     65 
     66     def _compare(self, _alpha, L, H):
     67         if _alpha > H:
     68             return H
     69         elif _alpha < L:
     70             return L
     71         else:
     72             return _alpha
     73 
     74     def fit(self, features, labels):
     75         self.init_args(features, labels)
     76 
     77         for t in range(self.max_iter):
     78             # train
     79             i1, i2 = self._init_alpha()
     80 
     81             # 设定边界
     82             if self.Y[i1] == self.Y[i2]:
     83                 L = max(0, self.alpha[i1] + self.alpha[i2] - self.C)
     84                 H = min(self.C, self.alpha[i1] + self.alpha[i2])
     85             else:
     86                 L = max(0, self.alpha[i2] - self.alpha[i1])
     87                 H = min(self.C, self.C + self.alpha[i2] - self.alpha[i1])
     88 
     89             E1 = self.E[i1]
     90             E2 = self.E[i2]
     91             # eta=K11+K22-2K12
     92             eta = self.kernel(self.X[i1], self.X[i1]) + self.kernel(
     93                 self.X[i2],
     94                 self.X[i2]) - 2 * self.kernel(self.X[i1], self.X[i2])
     95             if eta <= 0:
     96                 # print('eta <= 0')
     97                 continue
     98 
     99             alpha2_new_unc = self.alpha[i2] + self.Y[i2] * (
    100                 E1 - E2) / eta  # 第二版统计学习方法 P142~148
    101
    alpha2_new = self._compare(alpha2_new_unc, L, H) # 若超过边界,将alpha设定为边界值 102 103 alpha1_new = self.alpha[i1] + self.Y[i1] * self.Y[i2] * ( 104 self.alpha[i2] - alpha2_new) 105 106 b1_new = -E1 - self.Y[i1] * self.kernel(self.X[i1], self.X[i1]) * ( 107 alpha1_new - self.alpha[i1]) - self.Y[i2] * self.kernel( 108 self.X[i2], 109 self.X[i1]) * (alpha2_new - self.alpha[i2]) + self.b 110 b2_new = -E2 - self.Y[i1] * self.kernel(self.X[i1], self.X[i2]) * ( 111 alpha1_new - self.alpha[i1]) - self.Y[i2] * self.kernel( 112 self.X[i2], 113 self.X[i2]) * (alpha2_new - self.alpha[i2]) + self.b 114 115 if 0 < alpha1_new < self.C: 116 b_new = b1_new 117 elif 0 < alpha2_new < self.C: 118 b_new = b2_new 119 else: 120 # 选择中点 121 b_new = (b1_new + b2_new) / 2 122 123 # 更新参数 124 self.alpha[i1] = alpha1_new 125 self.alpha[i2] = alpha2_new 126 self.b = b_new 127 128 self.E[i1] = self._E(i1) 129 self.E[i2] = self._E(i2) 130 return 'train done!' 131 132 def predict(self, data): 133 r = self.b 134 for i in range(self.m): 135 r += self.alpha[i] * self.Y[i] * self.kernel(data, self.X[i]) 136 137 return 1 if r > 0 else -1 138 139 def score(self, X_test, y_test): 140 right_count = 0 141 for i in range(len(X_test)): 142 result = self.predict(X_test[i]) 143 if result == y_test[i]: 144 right_count += 1 145 return right_count / len(X_test) 146 147 def _weight(self): 148 # linear model 149 yx = self.Y.reshape(-1, 1) * self.X 150 self.w = np.dot(yx.T, self.alpha) 151 return self.w

    参照周志华西瓜书关于SVM的推导,有助于理解

    程序可以结合机器学习实战关于SVM代码部分学习

  • 相关阅读:
    IE window对象跨域的一些特性
    杭州归来
    网上流行的JS HTMLDecode不安全
    看到的一点进步
    开春第一趟单骑上妙峰
    把JS函数转URL形式
    firebug也支持debugger关键字了
    发现一篇关于flash垃圾回收机制的文章
    Java 内存分析图
    继承中多态的灵活使用及其分析图 第一个程序的升级版
  • 原文地址:https://www.cnblogs.com/ningjing213/p/13891084.html
Copyright © 2011-2022 走看看