zoukankan      html  css  js  c++  java
  • AdaBoost--从原理到实现(Code:Python)


           本文对原文有修改,若有疑虑,请移步原作者.  原文链接:blog.csdn.net/dark_scope/article/details/14103983

             集成方法在函数模型上等价于一个多层神经网络,两种常见的集成方法为Adaboost模型和RandomTrees模型。其中随机森林可被视为前馈神经网络,而Adaboost模型则等价于一个反馈型多层神经网络。

    一.引入

             对于Adaboost,可以说是久闻大名,据说在Deep Learning出来之前,SVM和Adaboost是效果最好的 两个算法,而Adaboost是提升树(boosting tree),所谓“ 提升树 ” 就是把“弱学习算法”提升(boost)为“强学习算法”(语自《统计学习方法》),而其中最具代表性的也就是Adaboost了,貌似Adaboost的结构还和Neural Network有几分神似,我倒没有深究过,不知道是不是有什么干货。

    二.过程

            (from PRML)

             这就是Adaboost的结构,最后的分类器YM是由数个弱分类器(weak classifier)组合而成的,相当于最后m个弱分类器来投票决定分类,而且每个弱分类器的“话语权”α不一样。

             这里阐述下算法的具体过程:

            1.初始化所有训练样例的权重为1 / N,其中N是样例数

            2.for m=1,……M:

                   a).训练弱分类器ym(),使其最小化权重误差函数(weighted error function)

               

                   b)接下来计算该弱分类器的话语权α:

                                                     

                   c)更新权重:

                                               

                                 其中Zm:

                                                     

                                 是规范化因子,使所有w的和为1。(这里公式稍微有点乱)

               3.得到最后的分类器:

                                  

    三.原理


                 可以看到整个过程就是和最上面那张图一样,前一个分类器改变权重w,同时组成最后的分类器
                 如果一个训练样例 在前一个分类其中被误分,那么它的权重会被加重,相应地,被正确分类的样例的权重会降低
                 使得下一个分类器 会更在意被误分的样例,那么其中那些α和w的更新是怎么来的呢?
                 下面我们从前项分步算法模型的角度来看看Adaboost:
                 直接将前项分步加法模型具体到adaboost上:
                                      
                  其中 fm是前m个分类器的结合
                                       
                  此时我们要最小化E,同时要考虑α和yl,
                  但现在我们假设前m-1个α和y都已经fixed了:那么
                                       
                   其中,可以被看做一个常量,因为它里面没有αm和ym:
                   接下来:
                                       
                   其中Tm表示正分类的集合,Mm表示误分类的集合,这一步其实就是把上面那个式子拆开,没什么复杂的东西
                   然后就是找ym了,就是最小化下式的过程,其实就是我们训练弱分类器
                                                    
                   有了ym,α也就可以找了,然后继续就可以找到更新w的公式了(注意这里得到的w公式是没有加规范化因子Z的公式,为了计算方便我们加了个Z进去)
                   因为这里算出来直接就是上面过程里的公式,就不再赘述了,有兴趣你可以自己算一算         

       

    四.实现

                   终于到实现了,本次实现代码基本基于《统计学习方法》,比如有些符号(弱分类器是G(x),训练样例的目标是y而不是上文所述的t)差异
                   所有的代码你可以在我写的toy toolkit里面找到:DML ( 你都看到这了,给个star好不好大哭 )
                
        # coding: UTF-8  
        from __future__ import division  
        import numpy as np  
        import scipy as sp  
        from weakclassify import WEAKC  
        from dml.tool import sign  
        class ADABC:  
            def __init__(self,X,y,Weaker=WEAKC):  
                ''''' 
                    Weaker is a class of weak classifier 
                    It should have a    train(self.W) method pass the weight parameter to train 
                                        pred(test_set) method which return y formed by 1 or -1 
                    see detail in <统计学习方法> 
                '''  
                self.X=np.array(X)  
                self.y=np.array(y)  
                self.Weaker=Weaker  
                self.sums=np.zeros(self.y.shape)  
                self.W=np.ones((self.X.shape[1],1)).flatten(1)/self.X.shape[1]  
                self.Q=0  
                #print self.W  
            def train(self,M=4):  
                ''''' 
                    M is the maximal Weaker classification 
                '''  
                self.G={}  
                self.alpha={}  
                for i in range(M):  
                    self.G.setdefault(i)  
                    self.alpha.setdefault(i)  
                for i in range(M):  
                    self.G[i]=self.Weaker(self.X,self.y)  
                    e=self.G[i].train(self.W)  
                    #print self.G[i].t_val,self.G[i].t_b,e  
                    self.alpha[i]=1/2*np.log((1-e)/e)  
                    #print self.alpha[i]  
                    sg=self.G[i].pred(self.X)  
                    Z=self.W*np.exp(-self.alpha[i]*self.y*sg.transpose())  
                    self.W=(Z/Z.sum()).flatten(1)  
                    self.Q=i  
                    #print self.finalclassifer(i),'==========='  
                    if self.finalclassifer(i)==0:  
          
                        print i+1," weak classifier is enough to  make the error to 0"  
                        break  
            def finalclassifer(self,t):  
                ''''' 
                    the 1 to t weak classifer come together 
                '''  
                self.sums=self.sums+self.G[t].pred(self.X).flatten(1)*self.alpha[t]  
                #print self.sums  
                pre_y=sign(self.sums)  
                #sums=np.zeros(self.y.shape)  
                #for i in range(t+1):  
                #   sums=sums+self.G[i].pred(self.X).flatten(1)*self.alpha[i]  
                #   print sums  
                #pre_y=sign(sums)  
                t=(pre_y!=self.y).sum()  
                return t  
            def pred(self,test_set):  
                sums=np.zeros(self.y.shape)  
                for i in range(self.Q+1):  
                    sums=sums+self.G[i].pred(self.X).flatten(1)*self.alpha[i]  
                    #print sums  
                pre_y=sign(sums)  
                return pre_y  



    看train里面的过程和上文 阐述的一模一样,finalclassifier()函数是用来判断是否已经无误分类的点 的
    当然这里用的Weak Classifier是比较基础的Decision Stump,是根据x>v和x<v来分类的,这个代码稍微烦一点,就不贴到这里了,在DML里也有
    先试验下《统计学习方法》里面那个最简单的例子:

    可以看到也是三个分类器就没有误分点了,权值的选择也是差不多的
    其中后面那个-1 表示大于threshold分为负类,小于分为正类。1则相反


    加一些其它数据试试:

    结果:
      
    我们把图画出来就是:

    基本还是正确的,这是四个子分类器的图,不是最后总分类器的图啊~~~
    (实验的代码你也可以在DML里面找到,你都看到这了,给个star好不好~~~~~大笑

    Reference:

          【1】 《Pattern Recognition And Machine Learning》
          【2】 《统计学习方法》
  • 相关阅读:
    Atitit nosql的艺术 attilax著作 目录 1. 1.5NoSQL数据库的类型 1 1.1. 1.5.1键值(Key/Value)存储 1 1.2. 1.5.2面向文档的数据库 1 1
    Atitit 常见信息化系统类别erp mes crm cms oa 目录 1.  企业资源规划(ERP)、客户关系管理(CRM)、协同管理系统(CMS)是企业信息化的三大代表之作 1 2. 概
    Atitit 信息管理概论 艾提拉总结 信息的采集 信息格式转换 信息整合 信息的tag标注 信息的结构化 信息检索,,索引 压缩 信息分析 汇总 第1章 信息管理的基本概念 第
    Atitit 产品化法通则 目录 1. 何谓软件产品化? 1 2. 产品化优点 vs 项目化 2 2.1. 软件复用率提高 2 2.2. ,项目化交付 2 2.3. 维护成本高 2 3. 产品金字塔
    Atitit 人工智能 统计学 机器学习的相似性 一些文摘收集 没有人工智能这门功课,人工智能的本质是统计学和数学,就是通过机器对数据的识别、计算、归纳和学习,然后做出下一步判断和决策的科学
    Atitit mybatis spring整合。读取spring、yml、文件的mysql url 步骤,读取yml,文件,使用ongl定位到url pwd usr 读取mybatis模板配置,
    关于一个大型web系统构架图的理解
    关于《王福朋petshop4.0视频教程》下载的更新
    不完全接触Node.js
    毕业设计那些事
  • 原文地址:https://www.cnblogs.com/wishchin/p/9200050.html
Copyright © 2011-2022 走看看