zoukankan      html  css  js  c++  java
  • 统计学习笔记之感知机

    1.感知机是二分类的线性分类模型,输出为实例的类别,取+1和-1二值。

    2.感知机属于判别模型。

      判别模型:判别方法由学习决策函数f(x)或者是条件概率分布P(Y|X)作为预测的模型。

      生成模型:生成方法由数据联合概率分布P(X,Y),然后求出条件概率分布P(Y|X)作为预测的模型。

    3.感知机的学习策略:首先假设数据是线性可分的,那么需要一个分离超平面把正负实例完全确定分开。为了找到这么一个平面,需要确定的是感知机参数w,b。又如何求得这两个参数值呢?便引入了(经验)损失函数,损失函数的选取则是根据误分类的点到超平面的距离来确定的,使损失函数取得极小化的参数w,b即为所求超平面的参数。

    4.算法

      (1)选取初值w0,b0

      (2)在训练集中选取数据(xi,yi)

      (3)如果yi*(w·xi+b)≤0

        w←w+ηyixi

        b←b+ηyi

      (4)转至(2),直至训练集中没有误分类点。

    5.代码

     1 #include <iostream>
     2 #include <fstream>
     3 using namespace std;
     4 /* run this program using the console pauser or add your own getch, system("pause") or input loop */
     5 
     6 const int num = 10;
     7 int count = 0;
     8 bool flag = true;
     9 
    10 struct DataSet  
    11 {  
    12     double x1;  
    13     double x2;  
    14     int y;  
    15 }data[num];  
    16   
    17 double eta = 1.0;//学习率  
    18 double w[2] = {0.0, 0.0}, b = 0.0;//定义参数  
    19   
    20 //从文件中读取数据  
    21 void readData()  
    22 {  
    23     ifstream file;  
    24     file.open("data.dat");  
    25     int i = 0;  
    26     while(!file.eof())  
    27     {  
    28         file >> data[i].x1 >> data[i].x2 >> data[i].y;  
    29         i++;  
    30         count++;  
    31     }  
    32     file.close();  
    33 }  
    34   
    35 int main()  
    36 {  
    37     int i;  
    38     int n = 0;//迭代次数  
    39     readData();//读入数据  
    40   
    41     //输出数据集  
    42     cout << "数据集为:" << endl;  
    43     for(i = 0; i < count; i++)  
    44     {  
    45         cout << data[i].x1 << "  " << data[i].x2 << "  " << data[i].y << endl;  
    46     }  
    47       
    48     while(flag)  
    49     {  
    50         for(i = 0; i < count; i++)  
    51         {  
    52             flag = false;  
    53             if( data[i].y * (w[0] * data[i].x1 + w[1] * data[i].x2 + b) <= 0)  
    54             {  
    55                 flag = true;  
    56                 w[0] = w[0] + eta * data[i].y * data[i].x1;  
    57                 w[1] = w[1] + eta * data[i].y * data[i].x2;  
    58                 b = b + eta * data[i].y;  
    59                 n++;  
    60                 break;  
    61             }  
    62         }  
    63     }  
    64   
    65     cout << endl << "结果:" << endl;   
    66     cout << "w = " << w[0] << ", " << w[1] << " " << "b=" << b << endl;  
    67     cout << "迭代次数:" << n << endl;  
    68   
    69     return 0;  
    70 }  

    data.dat是《统计学习方法》上例题数据。

    6.感知机学习算法由于采用不同的初值或者选择不同的误分类点,解可以不同。(毕竟随机梯度下降算法只有局部最优解,不能求全局最优解)

    7.感知机对偶形式的基本想法是,将w和b表示为实例xi和标记yi的线性组合形式,通过求解其系数而求得w和b。

      取αi=ni*η,当η=1时,表示第i个实例点由于误分而进行更新的次数。(学习过SVM就会发现,这些点可能是支持向量)

        w←w+ηyixi

        b←b+ηyi

      取初值w=b=0,最后学习到的w,b可以表示为:

      

      将(1)(2)带入原始形式的感知机模型:

      

      那么,学习目标就不再是w,而是ni,满足判别式之后,更新ni←ni+1。

      我们再回到《统计学习方法》的更新方式:

      αi←αi+η

      b←b+ηyi

      由于αi=ni*η,更新ni←ni+1,即αi←(ni+1)*η=αi+η。

      那么,这两种方式便得到了统一。

    8.对偶形式给原始形式带来了什么样的好处呢?

      由于实例是已知的,可以很快求出Gram矩阵,在判断的时候,可以直接调用Gram矩阵的值,来减少计算量!

  • 相关阅读:
    mac Redis相关配置,安装,启动,环境的配置。
    MySQL设置global变量和session变量的两种方法详解
    关于MySQL的锁以及数据脏读,重复读,幻读的笔记。
    MySQL新增数据,存在就更新,不存在就添加(转帖加实测)
    selenium 的显示等待和隐式等待的区别(记录加强版)
    MySQL字段与表的注释。转帖
    mysql格式化日期(转帖)
    通过Python用pymysql,通过sshtunnel模块ssh连接远程数据库。
    java io流
    openID 无效
  • 原文地址:https://www.cnblogs.com/hbwxcw/p/6806154.html
Copyright © 2011-2022 走看看