zoukankan      html  css  js  c++  java
  • libsvm代码阅读:关于Solver类分析(一)(转)

    如果你看完了上篇博文的伪代码,那么我们就可以开始谈谈它的源代码了。

    下面先贴出它的类定义,一些成员函数的具体实现先忽略。

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. // An SMO algorithm in Fan et al., JMLR 6(2005), p. 1889--1918  
    2. // Solves:  
    3. //  min 0.5(alpha^T Q alpha) + p^T alpha  
    4. //  
    5. //      y^T alpha = delta  
    6. //      y_i = +1 or -1  
    7. //      0 <= alpha_i <= Cp for y_i = 1  
    8. //      0 <= alpha_i <= Cn for y_i = -1  
    9. //  
    10. // Given:  
    11. //  Q, p, y, Cp, Cn, and an initial feasible point alpha  
    12. //  l is the size of vectors and matrices  
    13. //  eps is the stopping tolerance  
    14. // solution will be put in alpha, objective value will be put in obj  
    15. //  
    16. class Solver {  
    17. public:  
    18.     Solver() {};  
    19.     virtual ~Solver() {};//用虚析构函数的原因是:保证根据实际运行适当的析构函数  
    20.   
    21.     struct SolutionInfo {  
    22.         double obj;  
    23.         double rho;  
    24.         double upper_bound_p;  
    25.         double upper_bound_n;  
    26.         double r;   // for Solver_NU  
    27.     };  
    28.   
    29.     void Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,  
    30.            double *alpha_, double Cp, double Cn, double eps,  
    31.            SolutionInfo* si, int shrinking);  
    32. protected:  
    33.     int active_size;//计算时实际参加运算的样本数目,经过shrink处理后,该数目小于全部样本数  
    34.     schar *y;       //样本所属类别,该值只能取-1或+1。  
    35.     double *G;      // gradient of objective function = (Q alpha + p)  
    36.     enum { LOWER_BOUND, UPPER_BOUND, FREE };  
    37.     char *alpha_status; // LOWER_BOUND, UPPER_BOUND, FREE   
    38.     double *alpha;      //  
    39.     const QMatrix *Q;     
    40.     const double *QD;  
    41.     double eps;     //误差限  
    42.     double Cp,Cn;  
    43.     double *p;  
    44.     int *active_set;  
    45.     double *G_bar;      // gradient, if we treat free variables as 0  
    46.     int l;  
    47.     bool unshrink;  // XXX  
    48.     //返回对应于样本的C。设置不同的Cp和Cn是为了处理数据的不平衡  
    49.     double get_C(int i)  
    50.     {  
    51.         return (y[i] > 0)? Cp : Cn;  
    52.     }  
    53.   
    54.     void update_alpha_status(int i)  
    55.     {  
    56.         if(alpha[i] >= get_C(i))  
    57.             alpha_status[i] = UPPER_BOUND;  
    58.         else if(alpha[i] <= 0)  
    59.             alpha_status[i] = LOWER_BOUND;  
    60.         else alpha_status[i] = FREE;  
    61.     }  
    62.     bool is_upper_bound(int i) { return alpha_status[i] == UPPER_BOUND; }  
    63.     bool is_lower_bound(int i) { return alpha_status[i] == LOWER_BOUND; }  
    64.     bool is_free(int i) { return alpha_status[i] == FREE; }  
    65.     void swap_index(int i, int j);//交换样本i和j的内容,包括申请的内存的地址  
    66.     void reconstruct_gradient();  //重新计算梯度。  
    67.     virtual int select_working_set(int &i, int &j);//选择工作集  
    68.     virtual double calculate_rho();  
    69.     virtual void do_shrinking();//对样本集做缩减。  
    70. private:  
    71.     bool be_shrunk(int i, double Gmax1, double Gmax2);    
    72. };  

    下面我们来看看SMO如何选择工作集(working set B),选择的约束如下:

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. // return i,j such that  
    2. // i: maximizes -y_i * grad(f)_i, i in I_up(alpha)  
    3. // j: minimizes the decrease of obj value  
    4. //    (if quadratic coefficeint <= 0, replace it with tau)  
    5. //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(alpha)  

    论文中的公式如下:

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_3 height=18 name=ZeroClipboardMovie_3 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=3&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. int Solver::select_working_set(int &out_i, int &out_j)  
    2. {  
    3.     // return i,j such that  
    4.     // i: maximizes -y_i * grad(f)_i, i in I_up(alpha)  
    5.     // j: minimizes the decrease of obj value  
    6.     //    (if quadratic coefficeint <= 0, replace it with tau)  
    7.     //    -y_j*grad(f)_j < -y_i*grad(f)_i, j in I_low(alpha)  
    8. //select i    
    9.     double Gmax = -INF;  
    10.     double Gmax2 = -INF;  
    11.     int Gmax_idx = -1;  
    12.     int Gmin_idx = -1;  
    13.     double obj_diff_min = INF;  
    14.   
    15.     for(int t=0;t<active_size;t++)  
    16.         if(y[t]==+1)    //若类别为1  
    17.         {  
    18.             if(!is_upper_bound(t))//若alpha<C  
    19.                 if(-G[t] >= Gmax)  
    20.                 {  
    21.                     Gmax = -G[t];// -y[t]*G[t]=-1*G[t]  
    22.                     Gmax_idx = t;  
    23.                 }  
    24.         }  
    25.         else  
    26.         {  
    27.             if(!is_lower_bound(t))  
    28.                 if(G[t] >= Gmax)  
    29.                 {  
    30.                     Gmax = G[t];  
    31.                     Gmax_idx = t;  
    32.                 }  
    33.         }  
    34.   
    35.     int i = Gmax_idx;  
    36.     const Qfloat *Q_i = NULL;  
    37.     if(i != -1) // NULL Q_i not accessed: Gmax=-INF if i=-1  
    38.         Q_i = Q->get_Q(i,active_size);  
    39. //select j  
    40.     for(int j=0;j<active_size;j++)  
    41.     {  
    42.         if(y[j]==+1)  
    43.         {  
    44.             if (!is_lower_bound(j))  
    45.             {  
    46.                 double grad_diff=Gmax+G[j];  
    47.                 if (G[j] >= Gmax2)  
    48.                     Gmax2 = G[j];  
    49.                 if (grad_diff > 0)  
    50.                 {  
    51.                     double obj_diff;   
    52.                     double quad_coef = QD[i]+QD[j]-2.0*y[i]*Q_i[j];  
    53.                     if (quad_coef > 0)  
    54.                         obj_diff = -(grad_diff*grad_diff)/quad_coef;  
    55.                     else  
    56.                         obj_diff = -(grad_diff*grad_diff)/TAU;  
    57.   
    58.                     if (obj_diff <= obj_diff_min)  
    59.                     {  
    60.                         Gmin_idx=j;  
    61.                         obj_diff_min = obj_diff;  
    62.                     }  
    63.                 }  
    64.             }  
    65.         }  
    66.         else  
    67.         {  
    68.             if (!is_upper_bound(j))  
    69.             {  
    70.                 double grad_diff= Gmax-G[j];  
    71.                 if (-G[j] >= Gmax2)  
    72.                     Gmax2 = -G[j];  
    73.                 if (grad_diff > 0)  
    74.                 {  
    75.                     double obj_diff;   
    76.                     double quad_coef = QD[i]+QD[j]+2.0*y[i]*Q_i[j];  
    77.                     if (quad_coef > 0)  
    78.                         obj_diff = -(grad_diff*grad_diff)/quad_coef;  
    79.                     else  
    80.                         obj_diff = -(grad_diff*grad_diff)/TAU;  
    81.   
    82.                     if (obj_diff <= obj_diff_min)  
    83.                     {  
    84.                         Gmin_idx=j;  
    85.                         obj_diff_min = obj_diff;  
    86.                     }  
    87.                 }  
    88.             }  
    89.         }  
    90.     }  
    91.   
    92.     if(Gmax+Gmax2 < eps)  
    93.         return 1;  
    94.   
    95.     out_i = Gmax_idx;  
    96.     out_j = Gmin_idx;  
    97.     return 0;  
    98. }  

    配合上面几个公式看,这段代码还是很清晰了。

    下面来看看它的构造函数,这个构造函数是solver类的核心。这个算法也结合上一篇博文的algorithm2来看。其中要注意的是get_Q是获取核函数。

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_4 height=18 name=ZeroClipboardMovie_4 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=4&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. void Solver::Solve(int l, const QMatrix& Q, const double *p_, const schar *y_,  
    2.            double *alpha_, double Cp, double Cn, double eps,  
    3.            SolutionInfo* si, int shrinking)  
    4. {  
    5.     this->l = l;  
    6.     this->Q = &Q;  
    7.     QD=Q.get_QD();//这个是获取核函数(如果分类的话在SVC_Q中定义)  
    8.   
    9.     clone(p, p_,l);  
    10.     clone(y, y_,l);  
    11.     clone(alpha,alpha_,l);  
    12.   
    13.     this->Cp = Cp;  
    14.     this->Cn = Cn;  
    15.     this->eps = eps;  
    16.     unshrink = false;  
    17.   
    18.     // initialize alpha_status  
    19.     {  
    20.         alpha_status = new char[l];  
    21.         for(int i=0;i<l;i++)  
    22.             update_alpha_status(i);  
    23.     }  
    24.   
    25.     // initialize active set (for shrinking)  
    26.     {  
    27.         active_set = new int[l];  
    28.         for(int i=0;i<l;i++)  
    29.             active_set[i] = i;  
    30.         active_size = l;  
    31.     }  
    32.   
    33.     // initialize gradient  
    34.     {  
    35.         G = new double[l];  
    36.         G_bar = new double[l];  
    37.         int i;  
    38.         for(i=0;i<l;i++)  
    39.         {  
    40.             G[i] = p[i];  
    41.             G_bar[i] = 0;  
    42.         }  
    43.         for(i=0;i<l;i++)  
    44.             if(!is_lower_bound(i))  
    45.             {  
    46.                 const Qfloat *Q_i = Q.get_Q(i,l);  
    47.                 double alpha_i = alpha[i];  
    48.                 int j;  
    49.                 for(j=0;j<l;j++)  
    50.                     G[j] += alpha_i*Q_i[j];  
    51.                 if(is_upper_bound(i))  
    52.                     for(j=0;j<l;j++)  
    53.                         G_bar[j] += get_C(i) * Q_i[j]; //这里见文献LIBSVM: A Library for SVM公式(33)  
    54.             }  
    55.     }  
    56.   
    57.     // optimization step  
    58.   
    59.     int iter = 0;  
    60.     int max_iter = max(10000000, l>INT_MAX/100 ? INT_MAX : 100*l);  
    61.     int counter = min(l,1000)+1;  
    62.       
    63.     while(iter < max_iter)  
    64.     {  
    65.         // show progress and do shrinking  
    66.   
    67.         if(--counter == 0)  
    68.         {  
    69.             counter = min(l,1000);  
    70.             if(shrinking) do_shrinking();    
    71.             info(".");  
    72.         }  
    73.   
    74.         int i,j;  
    75.         if(select_working_set(i,j)!=0)  
    76.         {  
    77.             // reconstruct the whole gradient  
    78.             reconstruct_gradient();  
    79.             // reset active set size and check  
    80.             active_size = l;  
    81.             info("*");  
    82.             if(select_working_set(i,j)!=0)  
    83.                 break;  
    84.             else  
    85.                 counter = 1;    // do shrinking next iteration  
    86.         }  
    87.           
    88.         ++iter;  
    89.   
    90.         // update alpha[i] and alpha[j], handle bounds carefully  
    91.           
    92.         const Qfloat *Q_i = Q.get_Q(i,active_size);  
    93.         const Qfloat *Q_j = Q.get_Q(j,active_size);  
    94.   
    95.         double C_i = get_C(i);  
    96.         double C_j = get_C(j);  
    97.   
    98.         double old_alpha_i = alpha[i];  
    99.         double old_alpha_j = alpha[j];  
    100.   
    101.         if(y[i]!=y[j])  
    102.         {  
    103.             double quad_coef = QD[i]+QD[j]+2*Q_i[j];  
    104.             if (quad_coef <= 0)  
    105.                 quad_coef = TAU;  
    106.             double delta = (-G[i]-G[j])/quad_coef;  
    107.             double diff = alpha[i] - alpha[j];  
    108.             alpha[i] += delta;  
    109.             alpha[j] += delta;  
    110.               
    111.             if(diff > 0)  
    112.             {  
    113.                 if(alpha[j] < 0)  
    114.                 {  
    115.                     alpha[j] = 0;  
    116.                     alpha[i] = diff;  
    117.                 }  
    118.             }  
    119.             else  
    120.             {  
    121.                 if(alpha[i] < 0)  
    122.                 {  
    123.                     alpha[i] = 0;  
    124.                     alpha[j] = -diff;  
    125.                 }  
    126.             }  
    127.             if(diff > C_i - C_j)  
    128.             {  
    129.                 if(alpha[i] > C_i)  
    130.                 {  
    131.                     alpha[i] = C_i;  
    132.                     alpha[j] = C_i - diff;  
    133.                 }  
    134.             }  
    135.             else  
    136.             {  
    137.                 if(alpha[j] > C_j)  
    138.                 {  
    139.                     alpha[j] = C_j;  
    140.                     alpha[i] = C_j + diff;  
    141.                 }  
    142.             }  
    143.         }  
    144.         else  
    145.         {  
    146.             double quad_coef = QD[i]+QD[j]-2*Q_i[j];  
    147.             if (quad_coef <= 0)  
    148.                 quad_coef = TAU;  
    149.             double delta = (G[i]-G[j])/quad_coef;  
    150.             double sum = alpha[i] + alpha[j];  
    151.             alpha[i] -= delta;  
    152.             alpha[j] += delta;  
    153.   
    154.             if(sum > C_i)  
    155.             {  
    156.                 if(alpha[i] > C_i)  
    157.                 {  
    158.                     alpha[i] = C_i;  
    159.                     alpha[j] = sum - C_i;  
    160.                 }  
    161.             }  
    162.             else  
    163.             {  
    164.                 if(alpha[j] < 0)  
    165.                 {  
    166.                     alpha[j] = 0;  
    167.                     alpha[i] = sum;  
    168.                 }  
    169.             }  
    170.             if(sum > C_j)  
    171.             {  
    172.                 if(alpha[j] > C_j)  
    173.                 {  
    174.                     alpha[j] = C_j;  
    175.                     alpha[i] = sum - C_j;  
    176.                 }  
    177.             }  
    178.             else  
    179.             {  
    180.                 if(alpha[i] < 0)  
    181.                 {  
    182.                     alpha[i] = 0;  
    183.                     alpha[j] = sum;  
    184.                 }  
    185.             }  
    186.         }  
    187.   
    188.         // update G  
    189.   
    190.         double delta_alpha_i = alpha[i] - old_alpha_i;  
    191.         double delta_alpha_j = alpha[j] - old_alpha_j;  
    192.           
    193.         for(int k=0;k<active_size;k++)  
    194.         {  
    195.             G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;  
    196.         }  
    197.   
    198.         // update alpha_status and G_bar  
    199.   
    200.         {  
    201.             bool ui = is_upper_bound(i);  
    202.             bool uj = is_upper_bound(j);  
    203.             update_alpha_status(i);  
    204.             update_alpha_status(j);  
    205.             int k;  
    206.             if(ui != is_upper_bound(i))  
    207.             {  
    208.                 Q_i = Q.get_Q(i,l);  
    209.                 if(ui)  
    210.                     for(k=0;k<l;k++)  
    211.                         G_bar[k] -= C_i * Q_i[k];  
    212.                 else  
    213.                     for(k=0;k<l;k++)  
    214.                         G_bar[k] += C_i * Q_i[k];  
    215.             }  
    216.   
    217.             if(uj != is_upper_bound(j))  
    218.             {  
    219.                 Q_j = Q.get_Q(j,l);  
    220.                 if(uj)  
    221.                     for(k=0;k<l;k++)  
    222.                         G_bar[k] -= C_j * Q_j[k];  
    223.                 else  
    224.                     for(k=0;k<l;k++)  
    225.                         G_bar[k] += C_j * Q_j[k];  
    226.             }  
    227.         }  
    228.     }  
    229.   
    230.     if(iter >= max_iter)  
    231.     {  
    232.         if(active_size < l)  
    233.         {  
    234.             // reconstruct the whole gradient to calculate objective value  
    235.             reconstruct_gradient();  
    236.             active_size = l;  
    237.             info("*");  
    238.         }  
    239.         fprintf(stderr," WARNING: reaching max number of iterations ");  
    240.     }  
    241.   
    242.     // calculate rho  
    243.   
    244.     si->rho = calculate_rho();  
    245.   
    246.     // calculate objective value  
    247.     {  
    248.         double v = 0;  
    249.         int i;  
    250.         for(i=0;i<l;i++)  
    251.             v += alpha[i] * (G[i] + p[i]);  
    252.   
    253.         si->obj = v/2;  
    254.     }  
    255.   
    256.     // put back the solution  
    257.     {  
    258.         for(int i=0;i<l;i++)  
    259.             alpha_[active_set[i]] = alpha[i];  
    260.     }  
    261.   
    262.     // juggle everything back  
    263.     /*{ 
    264.         for(int i=0;i<l;i++) 
    265.             while(active_set[i] != i) 
    266.                 swap_index(i,active_set[i]); 
    267.                 // or Q.swap_index(i,active_set[i]); 
    268.     }*/  
    269.   
    270.     si->upper_bound_p = Cp;  
    271.     si->upper_bound_n = Cn;  
    272.   
    273.     info(" optimization finished, #iter = %d ",iter);  
    274.   
    275.     delete[] p;  
    276.     delete[] y;  
    277.     delete[] alpha;  
    278.     delete[] alpha_status;  
    279.     delete[] active_set;  
    280.     delete[] G;  
    281.     delete[] G_bar;  
    282. }  
  • 相关阅读:
    《Machine Learning in Action》—— 白话贝叶斯,“恰瓜群众”应该恰好瓜还是恰坏瓜
    《Machine Learning in Action》—— 女同学问Taoye,KNN应该怎么玩才能通关
    《Machine Learning in Action》—— Taoye给你讲讲决策树到底是支什么“鬼”
    深度学习炼丹术 —— Taoye不讲码德,又水文了,居然写感知器这么简单的内容
    《Machine Learning in Action》—— 浅谈线性回归的那些事
    《Machine Learning in Action》—— 懂的都懂,不懂的也能懂。非线性支持向量机
    《Machine Learning in Action》—— hao朋友,快来玩啊,决策树呦
    《Machine Learning in Action》—— 剖析支持向量机,优化SMO
    《Machine Learning in Action》—— 剖析支持向量机,单手狂撕线性SVM
    JVM 字节码指令
  • 原文地址:https://www.cnblogs.com/Miliery/p/4394140.html
Copyright © 2011-2022 走看看