zoukankan      html  css  js  c++  java
  • adaboost实践

    数据来源《统计学习方法》(李航):

    x=[ 0,1,2,3,4,5,6,7,8,9]
    y=[1,1,1, -1,-1,-1, 1,1,1,-1];

    实现:

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Sun Oct 14 18:52:18 2018
     4 @author: Administrator
     5 """
     6 #处理并获取数据
     7 import pandas as pd
     8 import math
     9 import numpy as np
    10 
    11 def getData():
    12     x=[ 0,1,2,3,4,5,6,7,8,9]
    13     y=[1,1,1, -1,-1,-1, 1,1,1,-1];
    14     z=[]
    15     for i in range(0,len(x)):
    16         z.append( [ x[i],y[i] ] );
    17     data=pd.DataFrame(z,columns=['x','y']);
    18     return data
    19 
    20 #基分类器
    21 def basicClassifier(Dweight,data):
    22     #采取x<v
    23     #阈值选择使分类误差最小的            
    24     threshold=0; #阈值
    25     error=1;     #分类误差率
    26     label=1;     # x<v 为正(负)类,
    27     for i in range( 0,len(data) ):
    28         tmpThreshold=data['x'][i];  #选择阈值
    29         tmpError1=0       #判断 label 为正类的误差
    30         tmpError2=0      #判断 label 为负类的误差 
    31         for j in range( 0,len(data) ):
    32             if( data['x'][j]<tmpThreshold ):
    33                 if( data['y'][j]==-1 ):      
    34                     tmpError1+=Dweight[j];# 如果label为正类
    35                 else:
    36                     tmpError2+=Dweight[j];# 如果label为负类
    37             else:
    38                 if( data['y'][j]==1 ):      
    39                     tmpError1+=Dweight[j];# label为正类
    40                 else:
    41                     tmpError2+=Dweight[j];# label为负类
    42         if( error>tmpError1 and tmpError1!=0 and tmpError1<tmpError2): 
    43             threshold=tmpThreshold;
    44             error=tmpError1;
    45             label=1 
    46         if( error>tmpError2 and tmpError1>tmpError2 and tmpError2!=0 ):
    47             threshold=tmpThreshold;    
    48             label=-1; 
    49             error=tmpError2;
    50     #求该基本分类器的权重       
    51     alpha = math.log((1-error)/error)/2 ;
    52     #更新数据权重
    53     NewDweight=[]
    54     sumAll=0 #规范化因子
    55     for i in range(0,len(data)):
    56         if( data['x'][i]< threshold  ):
    57             Gm=label
    58         else:
    59             Gm=label*(-1)
    60         vv=Dweight[i]*( math.e**(alpha*data['y'][i]*Gm*(-1)) ) 
    61         NewDweight.append(vv);
    62         sumAll+=vv; 
    63     NewDweight = np.array(NewDweight);
    64     NewDweight = NewDweight/sumAll;
    65     return threshold,label,alpha,list(NewDweight);
    66 
    67 #最后的分类器
    68 def adaboost(data):
    69     threshold=[];  #每个基本分类器的阈值
    70     label=[];      #每个基本分类器 x<v 类别
    71     Cweight=[];    #每个基本分类器的权重
    72     Dweight=[];    #每项数据权重
    73     for i in range(0,len(data)):
    74         Dweight.append(1/len(data));
    75     M=3;          # M个基本分类器组合
    76     for i in range(0,M):
    77         th, la, Cw ,Dweight= basicClassifier(Dweight,data)
    78         threshold.append(th)
    79         label.append(la)
    80         Cweight.append(Cw)
    81     return [ threshold, label,Cweight ]
    82 
    83 def classify(data, model):
    84     for i in range(0,len(data)):
    85         val=0;
    86         for j in range(0,len(model[0])):
    87             if( data['x'][i]< model[0][j] ):
    88                 val+=( model[1][j]*model[2][j] );
    89             else:
    90                 val+=( model[1][j]*model[2][j]*(-1) );
    91         if(val<0):
    92             print('-1')
    93         else:
    94             print('1');
    95 data = getData();  
    96 mo=adaboost(data);
    97 classify(data,mo)
  • 相关阅读:
    Repeater 双向排序
    将具有固定格式的text 类型中的数据分离出来的一种方法
    ASP.NET 快速构建应用程序页面主框架
    2分分页处理存储过程通用存储过程
    C#3.0之匿名类型
    常用的js收集
    用CSS实现DataGird滚动而表头不动
    Lucene.Net 创建索引和检索
    Lucene.net 实现全文搜索
    SQL 中操作日期的几个函数
  • 原文地址:https://www.cnblogs.com/z-bear/p/9789608.html
Copyright © 2011-2022 走看看