zoukankan      html  css  js  c++  java
  • 统计学习方法(2)——感知机

    感知机是二类分类的线性分类模型,其输入为实例的特征向量,输出为实例的类别{-1,1},是一种判别模型。感知机学习的目的在于求出将训练数据进行划分的超平面。

    • 感知机模型


    输入空间(Xepsilon R^{n}),输出空间(gamma =left { -1,1 ight })

    [ f(x)=sign(wcdot x+b)$$ $x$为输入向量,其中,$w$和$b$为感知机模型参数,$wcdot b$表示内积,sign是符号函数。感知机的几何角度理解是:$$wcdot x+b=0$$是特征空间$R^{n}$的一个超平面,$w$是该平面的法向量,$b$是截距。这个超平面将特征空间划分为正负两个部分,如下图。 ![](https://upload-images.jianshu.io/upload_images/7490554-d306524a627a15b3.png?imageMogr2/auto-orient/strip%7CimageView2/2/w/1240) - ####感知机学习策略 *** 感知机学习的目的是为了找到能够将正负实例点正确分开的超平面,也就是要确定参数$w$和$b$,感知机的学习策略便是__定义一个损失函数并将其最小化__。于是便要选择一个损失函数的依据,可以选择误分类的点的数量作为损失函数,然而该函数不可导,不易于优化,因此选择误分类点到超平面的距离和:$$frac{left | wcdot x +b ight |}{left | w ight |}$$ 此处${left | w ight |}$是$w$的第二范数。注意需要优化的只是误分类的点,对于误分类的点有,$$-y_i(wcdot x + b)>0$$恒成立,因此可去掉绝对值符号,并假设当前超平面的误分类的点的集合为M,由此得到感知机学习的损失函数为$$L(w,b)=-sum_{x_iin M}y_i(wcdot x_i+b)$$ 其中M为误分类的点的集合。显然该损失函数是非负的,当没有误分类的点时$L(w,b)=0$.只需将损失函数优化到0即得到该分类超平面,不过由该方法得到的超平面的解不是唯一的(显然只需要能够正确分类时算法即停止)。 - ####感知机学习算法 *** 感知机所用优化方法是随机梯度下降法,包括原始形式和对偶形式。 1. ######原始形式 前面已经确定了感知机的损失函数,那么其原始形式只需要最小化这个损失函数即可。 $$underset{w,b}{min}L(w,b)=-sum_{x_iin M}y_i(wcdot x+b)$$其中M为误分类的点的集合。 随机梯度下降法初始时任选$w_0$,$b_0$作为初始超平面,计算有哪些误分类点,如果有误分类点,随机选取一个误分类点,进行梯度下降。即先计算损失函数的梯度 ]

    egin{aligned}
    riangledown wL(w,b)&=-sum{x_iin M}y_ix_i
    riangledown_wL(w,b)&=-sum_{x_iin M}y_i
    end{aligned}

    [梯度下降法使参数向反方向变化,使用随机选出的误分类点的数据,根据提前设置好的学习率$eta$对$w,b$进行更新就可以了 ]

    egin{aligned}
    w& leftarrow w+eta y_ix_i
    b& leftarrow b+eta y_i
    end{aligned}

    [这样便可使损失函数不断减小,直到为0时就得到了可正确分类数据集的超平面。 2. ######对偶形式 在原始形式的学习算法中,可以看到每次更新$w,b$的数值都是选中的点$(x_i,y_i)$的线性组合,那么$w,b$必然可以用$(x_i,y_i)$线性表示,这样我们可以通过求解该线性组合的系数找到该超平面。对上节$w,b$的更新中,设总共修改N次,可将每次$w,b$增量表示为$alpha _iy_ix_i,alpha _iy_i$,其中$alpha = n_ieta$,假设$w_0=b_0=0$(这无关线性)。于是更新过程表示为 ]

    egin{aligned}
    w&=sum_ialpha _iy_ix_i
    b&=sum_i alpha _iy_i
    end{aligned}

    [这里$alpha _i=n_ieta _i$的含义是在该学习率下$(x_i,y_i)$在最后学习到的$w,b$中所贡献的权重,就是最后平面的$w,b$的系数,也是因该点误分类也进行更新的次数*$eta$。由此,感知机模型可由$alpha ,b$表出。 $$f(x)=sign(sum_jalpha _jy_jcdot x + b)$$在判断是否是误分类点时用 $$y_i(sum _jalpha _jy_jx_jcdot x_i + b)leqslant 0$$更新时 ]

    egin{aligned}
    alpha _i &leftarrow alpha _i +eta
    b &leftarrow b + eta y_i
    end{aligned}

    [可以看到该计算过程中训练数据全部由内积得到,因此可以提前将内积计算出来由矩阵存储,可以减少算法过程中的计算量,这是Gram矩阵。$$G= [x_i cdot x_j]_{N*N} ]

  • 相关阅读:
    vbscript错误代码及对应解释大全(希望还没过时)
    对象存储服务MinIO安装部署分布式及Spring Boot项目实现文件上传下载
    一道算法题,求更好的解法
    浅谈SQLite——实现与应用
    Linux网络协议栈(二)——套接字缓存(socket buffer)
    服务器开发入门——理解异步I/O
    理解MySQL——复制(Replication)
    线性时间排序算法
    Linux网络协议栈(一)——Socket入门(2)
    理解MySQL——索引与优化
  • 原文地址:https://www.cnblogs.com/breezezz/p/11151116.html
Copyright © 2011-2022 走看看