zoukankan      html  css  js  c++  java
  • AdaBoost

    # coding=utf-8

    # python 3.5

    '''
    Created on 2017年11月27日

    @author: Scorpio.Lu
    '''

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.ensemble import AdaBoostClassifier
    from sklearn.tree import DecisionTreeClassifier
    # 参考网址:https://louisscorpio.github.io/2017/11/28/代码实战之AdaBoost/
    from sklearn.datasets import make_gaussian_quantiles
    from sklearn.model_selection import train_test_split
    from sklearn import metrics
    import pandas as pd

    # 用make_gaussian_quantiles生成多组多维正态分布的数据

    # 这里生成4维正态分布,设定样本数200,协方差2

    # First synthetic set: 200 samples, 4 features, 2 classes, covariance 2.
    x1, y1 = make_gaussian_quantiles(
        cov=2.,
        n_samples=200,
        n_features=4,
        n_classes=2,
        shuffle=True,
        random_state=1,
    )

    # Second set, centred at (3, 3, 3, 3), to make the overall sample
    # distribution more complex.
    x2, y2 = make_gaussian_quantiles(
        mean=(3, 3, 3, 3),
        cov=1.5,
        n_samples=300,
        n_features=4,
        n_classes=2,
        shuffle=True,
        random_state=1,
    )

    # Merge both sets: features are stacked row-wise, labels are joined
    # end-to-end with the second set's 0/1 labels flipped (1 - y2).
    # NOTE: X and y are reassigned further down from the candidates DataFrame.
    X = np.concatenate((x1, x2), axis=0)
    y = np.concatenate((y1, 1 - y2))

    # 第一步构建数据

    # Hand-entered records, one list per column; every list has 40 entries.
    candidates = {
        'gmat': [
            780, 750, 690, 710, 680, 730, 690, 720, 740, 690,
            610, 690, 710, 680, 770, 610, 580, 650, 540, 590,
            620, 600, 550, 550, 570, 670, 660, 580, 650, 660,
            640, 620, 660, 660, 680, 650, 670, 580, 590, 690,
        ],
        'gpa': [
            4, 3.9, 3.3, 3.7, 3.9, 3.7, 2.3, 3.3, 3.3, 1.7,
            2.7, 3.7, 3.7, 3.3, 3.3, 3, 2.7, 3.7, 2.7, 2.3,
            3.3, 2, 2.3, 2.7, 3, 3.3, 3.7, 2.3, 3.7, 3.3,
            3, 2.7, 4, 3.3, 3.3, 2.3, 2.7, 3.3, 1.7, 3.7,
        ],
        'work_experience': [
            3, 4, 3, 5, 4, 6, 1, 4, 5, 1,
            3, 5, 6, 4, 3, 1, 4, 6, 2, 3,
            2, 1, 4, 1, 2, 6, 4, 2, 6, 5,
            1, 2, 4, 6, 5, 1, 2, 1, 4, 5,
        ],
        'admitted': [
            1, 1, 1, 1, 1, 1, 0, 1, 1, 0,
            0, 1, 1, 1, 1, 0, 0, 1, 0, 0,
            0, 0, 0, 0, 0, 1, 1, 0, 1, 1,
            0, 0, 1, 1, 1, 0, 0, 0, 0, 1,
        ],
    }
    # Build the frame with an explicit column order.
    df = pd.DataFrame(
        candidates,
        columns=['gmat', 'gpa', 'work_experience', 'admitted'],
    )
    # Feature matrix: the three predictor columns; target: the 'admitted' column.
    X = df.loc[:, ['gmat', 'gpa', 'work_experience']]
    y = df.loc[:, 'admitted']

    # 设定弱分类器CART

    # Weak learner: a CART decision tree limited to depth 1.
    weakClassifier = DecisionTreeClassifier(max_depth=1)
    # Hold out 25% of the rows for evaluation; random_state=0 fixes the split.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=0
    )

    # Build the model.
    # The weak learner is passed positionally rather than by keyword: the
    # keyword was renamed from `base_estimator` to `estimator` in
    # scikit-learn 1.2 and the old name was removed in 1.4, while the first
    # positional slot works with both old and new releases.
    clf = AdaBoostClassifier(
        weakClassifier,
        algorithm='SAMME',
        n_estimators=1000,
        learning_rate=0.01,
    )
    clf.fit(X_train, y_train)

    # Predict on the held-out rows and report the accuracy.
    y_pred = clf.predict(X_test)
    print(y_pred)
    print('精度: ', metrics.accuracy_score(y_test, y_pred))

  • 相关阅读:
    CentOS之文件搜索命令locate
    CentOs之链接命令
    CentOs之常见目录作用介绍
    centOs之目录处理命令
    Query注解及方法限制
    Repository接口
    OkHttp和Volley对比
    Base64加密与MD5的区别?
    支付宝集成
    Android 中 非对称(RSA)加密和对称(AES)加密
  • 原文地址:https://www.cnblogs.com/131415-520/p/11776317.html
Copyright © 2011-2022 走看看