zoukankan      html  css  js  c++  java
  • libsvm代码阅读:关于Kernel类分析(转)

    这一篇博文来分析下Kernel类,代码上很简单,一般都能看懂。Kernel类主要是为SVM的核函数服务的,里面实现了SVM常用的核函数,通过函数指针来使用这些核函数。

    其中几个常用核函数如下所示:(一般情况下,使用RBF核函数能取得很好的效果)

    关于基类QMatrix在Kernel中的作用并不明显,只是定义了一些纯虚函数,Kernel继承这些函数,Kernel只对swap_index进行了定义。其余的get_Q和get_QD在Kernel并没有用到。

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_1 height=18 name=ZeroClipboardMovie_1 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=1&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. class QMatrix {  
    2. public:  
    3.     virtual Qfloat *get_Q(int column, int len) const = 0;//纯虚函数,在子类中实现,important!  
    4.     virtual double *get_QD() const = 0;  
    5.     virtual void swap_index(int i, int j) const = 0;  
    6.     virtual ~QMatrix() {}  
    7. };  

    Kernel类的定义函数,比较简单就不细说。

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_2 height=18 name=ZeroClipboardMovie_2 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=2&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
    1. class Kernel: public QMatrix {  
    2. public:  
    3.     Kernel(int l, svm_node * const * x, const svm_parameter& param);  
    4.     virtual ~Kernel();  
    5.   
    6.     static double k_function(const svm_node *x, const svm_node *y,  
    7.                  const svm_parameter& param);  
    8.     virtual Qfloat *get_Q(int column, int len) const = 0;  
    9.     virtual double *get_QD() const = 0;  
    10.     virtual void swap_index(int i, int j) const // no so const...  
    11.     {  
    12.         swap(x[i],x[j]);  
    13.         if(x_square) swap(x_square[i],x_square[j]);  
    14.     }  
    15. protected:  
    16.   
    17.     double (Kernel::*kernel_function)(int i, int j) const;  
    18.   
    19. private:  
    20.     const svm_node **x;//用来指向样本数据,每次数据传入时通过克隆函数来实现,完全重新分配内存,主要是为处理多类着想  
    21.     double *x_square;//使用RBF 核才使用  
    22.   
    23.     // svm_parameter  
    24.     const int kernel_type;  
    25.     const int degree;  
    26.     const double gamma;  
    27.     const double coef0;  
    28.   
    29.     static double dot(const svm_node *px, const svm_node *py);  
    30.   
    31.     double kernel_linear(int i, int j) const  
    32.     {  
    33.         return dot(x[i],x[j]);  
    34.     }  
    35.     double kernel_poly(int i, int j) const  
    36.     {  
    37.         return powi(gamma*dot(x[i],x[j])+coef0,degree);  
    38.     }  
    39.     double kernel_rbf(int i, int j) const  
    40.     {  
    41.         return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));  
    42.     }  
    43.   
    44.     double kernel_sigmoid(int i, int j) const  
    45.     {  
    46.         return tanh(gamma*dot(x[i],x[j])+coef0);  
    47.     }  
    48.     double kernel_precomputed(int i, int j) const  
    49.     {  
    50.         return x[i][(int)(x[j][0].value)].value;  
    51.     }  
    52. };  

    这个Kernel类的函数比较清晰,我直接把它的全部代码贴出。

    全部代码如下:

    [cpp]   view plain copy 在CODE上查看代码片 派生到我的代码片
    <EMBED id=ZeroClipboardMovie_3 height=18 name=ZeroClipboardMovie_3 type=application/x-shockwave-flash align=middle pluginspage=http://www.macromedia.com/go/getflashplayer width=18 src=http://static.blog.csdn.net/scripts/ZeroClipboard/ZeroClipboard.swf wmode="transparent" flashvars="id=3&width=18&height=18" allowfullscreen="false" allowscriptaccess="always" bgcolor="#ffffff" quality="best" menu="false" loop="false">
      1. //  
      2. // Kernel evaluation  
      3. //  
      4. // the static method k_function is for doing single kernel evaluation  
      5. // the constructor of Kernel prepares to calculate the l*l kernel matrix  
      6. // the member function get_Q is for getting one column from the Q Matrix  
      7. //  
      8. class QMatrix {  
      9. public:  
      10.     virtual Qfloat *get_Q(int column, int len) const = 0;  
      11.     virtual double *get_QD() const = 0;  
      12.     virtual void swap_index(int i, int j) const = 0;  
      13.     virtual ~QMatrix() {}  
      14. };  
      15.   
      16. class Kernel: public QMatrix {  
      17. public:  
      18.     Kernel(int l, svm_node * const * x, const svm_parameter& param);//构造函数  
      19.     virtual ~Kernel();  
      20.   
      21.     static double k_function(const svm_node *x, const svm_node *y,  
      22.                  const svm_parameter& param);  
      23.     virtual Qfloat *get_Q(int column, int len) const = 0;  
      24.     virtual double *get_QD() const = 0;  
      25.     virtual void swap_index(int i, int j) const // no so const...  
      26.     {  
      27.         swap(x[i],x[j]);  
      28.         if(x_square) swap(x_square[i],x_square[j]);  
      29.     }  
      30. protected:  
      31.   
      32.     double (Kernel::*kernel_function)(int i, int j) const;  
      33.   
      34. private:  
      35.     const svm_node **x;//用来指向样本数据,每次数据传入时通过克隆函数来实现,完全重新分配内存,主要是为处理多类着想  
      36.     double *x_square;//使用RBF 核才使用  
      37.   
      38.     // svm_parameter  
      39.     const int kernel_type;  
      40.     const int degree;  
      41.     const double gamma;  
      42.     const double coef0;  
      43.   
      44.     static double dot(const svm_node *px, const svm_node *py);  
      45.   
      46.     double kernel_linear(int i, int j) const  
      47.     {  
      48.         return dot(x[i],x[j]);  
      49.     }  
      50.     double kernel_poly(int i, int j) const  
      51.     {  
      52.         return powi(gamma*dot(x[i],x[j])+coef0,degree);  
      53.     }  
      54.     double kernel_rbf(int i, int j) const  
      55.     {  
      56.         return exp(-gamma*(x_square[i]+x_square[j]-2*dot(x[i],x[j])));  
      57.     }  
      58.   
      59.     double kernel_sigmoid(int i, int j) const  
      60.     {  
      61.         return tanh(gamma*dot(x[i],x[j])+coef0);  
      62.     }  
      63.     double kernel_precomputed(int i, int j) const  
      64.     {  
      65.         return x[i][(int)(x[j][0].value)].value;  
      66.     }  
      67. };  
      68.   
      69. //构造函数,初始化类中的部分常量,指定核函数,克隆样本数据。如果使用RBF核函数,则计算x_square[i]  
      70. Kernel::Kernel(int l, svm_node * const * x_, const svm_parameter& param)  
      71. :kernel_type(param.kernel_type), degree(param.degree),  
      72.  gamma(param.gamma), coef0(param.coef0)  
      73. {  
      74.     switch(kernel_type)  
      75.     {  
      76.         case LINEAR:  
      77.             kernel_function = &Kernel::kernel_linear;  
      78.             break;  
      79.         case POLY:  
      80.             kernel_function = &Kernel::kernel_poly;  
      81.             break;  
      82.         case RBF:  
      83.             kernel_function = &Kernel::kernel_rbf;  
      84.             break;  
      85.         case SIGMOID:  
      86.             kernel_function = &Kernel::kernel_sigmoid;  
      87.             break;  
      88.         case PRECOMPUTED:  
      89.             kernel_function = &Kernel::kernel_precomputed;  
      90.             break;  
      91.     }  
      92.   
      93.     clone(x,x_,l);//void clone(T*& dst, S* src, int n)  
      94.   
      95.     if(kernel_type == RBF)  
      96.     {  
      97.         x_square = new double[l];  
      98.         for(int i=0;i<l;i++)  
      99.             x_square[i] = dot(x[i],x[i]);  
      100.     }  
      101.     else  
      102.         x_square = 0;  
      103. }  
      104.   
      105. Kernel::~Kernel()  
      106. {  
      107.     delete[] x;  
      108.     delete[] x_square;  
      109. }  
      110.   
      111. double Kernel::dot(const svm_node *px, const svm_node *py)  
      112. {  
      113.     double sum = 0;  
      114.     while(px->index != -1 && py->index != -1)  
      115.     {  
      116.         if(px->index == py->index)  
      117.         {  
      118.             sum += px->value * py->value;  
      119.             ++px;  
      120.             ++py;  
      121.         }  
      122.         else  
      123.         {  
      124.             if(px->index > py->index)  
      125.                 ++py;  
      126.             else  
      127.                 ++px;  
      128.         }             
      129.     }  
      130.     return sum;  
      131. }  
      132.   
      133. double Kernel::k_function(const svm_node *x, const svm_node *y,  
      134.               const svm_parameter& param)  
      135. {  
      136.     switch(param.kernel_type)  
      137.     {  
      138.         case LINEAR:  
      139.             return dot(x,y);  
      140.         case POLY:  
      141.             return powi(param.gamma*dot(x,y)+param.coef0,param.degree);  
      142.         case RBF:  
      143.         {  
      144.             double sum = 0;  
      145.             while(x->index != -1 && y->index !=-1)  
      146.             {  
      147.                 if(x->index == y->index)  
      148.                 {  
      149.                     double d = x->value - y->value;  
      150.                     sum += d*d;  
      151.                     ++x;  
      152.                     ++y;  
      153.                 }  
      154.                 else  
      155.                 {  
      156.                     if(x->index > y->index)  
      157.                     {     
      158.                         sum += y->value * y->value;  
      159.                         ++y;  
      160.                     }  
      161.                     else  
      162.                     {  
      163.                         sum += x->value * x->value;  
      164.                         ++x;  
      165.                     }  
      166.                 }  
      167.             }  
      168.   
      169.             while(x->index != -1)  
      170.             {  
      171.                 sum += x->value * x->value;  
      172.                 ++x;  
      173.             }  
      174.   
      175.             while(y->index != -1)  
      176.             {  
      177.                 sum += y->value * y->value;  
      178.                 ++y;  
      179.             }  
      180.               
      181.             return exp(-param.gamma*sum);  
      182.         }  
      183.         case SIGMOID:  
      184.             return tanh(param.gamma*dot(x,y)+param.coef0);  
      185.         case PRECOMPUTED:  //x: test (validation), y: SV  
      186.             return x[(int)(y->value)].value;  
      187.         default:  
      188.             return 0;  // Unreachable   
      189.     }  
      190. }  
  • 相关阅读:
    MYSQL数据库导入SQL文件出现乱码解决方法
    Mysql设置允许用户可以连接
    MongoDB 设置权限认证
    NodeJs 服务端调试
    Hudson 定时编译
    Ubuntu上NodeJs环境安装
    新开通博客
    war类型项目创建
    Maven项目创建
    Maven简介与配置
  • 原文地址:https://www.cnblogs.com/Miliery/p/4394138.html
Copyright © 2011-2022 走看看