zoukankan      html  css  js  c++  java
  • AdaBoost 实践

    数据来源:《统计学习方法》(李航)例 8.1:

    x = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    y = [1, 1, 1, -1, -1, -1, 1, 1, 1, -1]

    实现:

     1 # -*- coding: utf-8 -*-
     2 """
     3 Created on Sun Oct 14 18:52:18 2018
     4 @author: Administrator
     5 """
     6 #处理并获取数据
     7 import pandas as pd
     8 import math
     9 import numpy as np
    10 
    def getData():
        """Return the toy training set used throughout this example.

        The data comes from Li Hang's *Statistical Learning Methods*. The result
        is a DataFrame with two integer columns: ``x`` (the values 0..9) and
        ``y`` (the +1/-1 class label of each x), one row per sample.
        """
        xs = list(range(10))
        ys = [1, 1, 1, -1, -1, -1, 1, 1, 1, -1]
        # One [x, y] row per sample, in the same order as xs/ys.
        rows = [list(pair) for pair in zip(xs, ys)]
        return pd.DataFrame(rows, columns=['x', 'y'])
    19 
    # Base (weak) classifier: a decision stump on the single feature ``x``.
    def basicClassifier(Dweight, data):
        """Fit one decision stump under the given sample weights and reweight the data.

        The stump has the form ``G(x) = label if x < threshold else -label``.
        Every observed ``x`` value is tried as the threshold, with both
        polarities, and the one with the smallest weighted error is kept (the
        first such threshold on ties).

        Args:
            Dweight: sequence of per-sample weights, aligned by position with the
                rows of ``data``.
            data: DataFrame with an ``x`` column (feature) and a ``y`` column
                (labels, compared against +1 and -1).

        Returns:
            ``(threshold, label, alpha, new_weights)``: the chosen threshold, the
            class predicted for ``x < threshold``, the stump's vote weight
            ``0.5 * ln((1 - e) / e)`` for weighted error ``e``, and the updated
            sample weights (normalized to sum to 1) as a list.

        Raises:
            ValueError: if ``data`` is empty or ``Dweight`` has a different length.
        """
        # Positional access: independent of whatever index ``data`` carries.
        xs = data['x'].tolist()
        ys = data['y'].tolist()
        if not xs:
            raise ValueError("basicClassifier needs at least one sample")
        if len(Dweight) != len(xs):
            raise ValueError("Dweight must have one weight per row of data")

        threshold, label, error = None, 1, math.inf
        for v in xs:
            err_pos = 0  # weighted error if x < v is predicted +1 (label=1)
            err_neg = 0  # weighted error if x < v is predicted -1 (label=-1)
            for xj, yj, wj in zip(xs, ys, Dweight):
                wrong_if_pos = (yj == -1) if xj < v else (yj == 1)
                if wrong_if_pos:
                    err_pos += wj
                else:
                    err_neg += wj
            # Keep the better polarity for this threshold.
            cand_err, cand_label = (err_neg, -1) if err_neg < err_pos else (err_pos, 1)
            # Strict '<' keeps the first threshold among equal errors; a
            # zero-error (perfect) stump is a valid choice.
            if cand_err < error:
                threshold, label, error = v, cand_label, cand_err

        # A perfect stump (error 0) would make alpha infinite and divide by
        # zero; clip so the log argument stays positive and alpha finite.
        eps = 1e-10
        clipped = min(max(error, eps), 1 - eps)
        alpha = 0.5 * math.log((1 - clipped) / clipped)

        # Reweight: w_i * exp(-alpha * y_i * G(x_i)), then normalize.
        new_weights = []
        for xj, yj, wj in zip(xs, ys, Dweight):
            g = label if xj < threshold else -label
            new_weights.append(wj * math.exp(-alpha * yj * g))
        new_weights = np.array(new_weights)
        new_weights = new_weights / new_weights.sum()
        return threshold, label, alpha, list(new_weights)
    66 
    # Final classifier: a weighted vote of M decision stumps.
    def adaboost(data, M=3):
        """Train an AdaBoost ensemble of ``M`` decision stumps on ``data``.

        Sample weights start uniform and are updated by each call to
        ``basicClassifier``, whose returned weights feed the next round.

        Args:
            data: DataFrame with ``x`` and ``y`` columns, passed unchanged to
                ``basicClassifier``.
            M: number of boosting rounds, i.e. base classifiers to combine
                (default 3).

        Returns:
            ``[thresholds, labels, alphas]``: three parallel lists where entry
            ``j`` is stump j's threshold, the class it predicts for
            ``x < threshold``, and its vote weight.
        """
        n = len(data)
        # Uniform initial distribution; written so that n == 0 does not divide by zero.
        weights = [1 / n for _ in range(n)]
        thresholds, labels, alphas = [], [], []
        for _ in range(M):
            th, la, alpha, weights = basicClassifier(weights, data)
            thresholds.append(th)
            labels.append(la)
            alphas.append(alpha)
        return [thresholds, labels, alphas]
    82 
    def classify(data, model):
        """Print, and return, the ensemble's prediction for every row of ``data``.

        Stump j votes ``model[1][j] * model[2][j]`` when ``x < model[0][j]``
        and the negated amount otherwise. A negative total prints ``-1``;
        anything else (including 0) prints ``1``.

        Args:
            data: DataFrame with an ``x`` column; rows are read in positional order.
            model: ``[thresholds, labels, alphas]`` as returned by ``adaboost``.

        Returns:
            list of int predictions (``1`` or ``-1``), one per row, in row order.
        """
        thresholds, labels, alphas = model[0], model[1], model[2]
        predictions = []
        # Positional iteration: works regardless of the DataFrame's index labels.
        for x in data['x'].tolist():
            score = 0
            for th, la, alpha in zip(thresholds, labels, alphas):
                vote = la * alpha
                score += vote if x < th else -vote
            pred = -1 if score < 0 else 1
            print(pred)
            predictions.append(pred)
        return predictions
    # Script entry: train the ensemble on the toy data set, then print one
    # predicted label per training sample.
    data = getData()
    mo = adaboost(data)
    classify(data, mo)
  • 相关阅读:
    vi常用操作
    Python练习题
    Jmeter也能IP欺骗!
    mysql主从配置
    性能测试之mysql监控、优化
    Git 命令
    Chrome——F12 谷歌开发者工具详解
    Appscan
    微信群发红包抢红包设计测试用例
    MySQL基础篇(1)SQL基础
  • 原文地址:https://www.cnblogs.com/z-bear/p/9789608.html
Copyright © 2011-2022 走看看