zoukankan      html  css  js  c++  java
  • GBDT实战

    数据集地址:UCI 机器学习库中的 Iris 数据集(https://archive.ics.uci.edu/ml/datasets/iris),下载其中的 iris.data 文件放在脚本同目录下。

    基于sklearn接口的分类

    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.metrics import accuracy_score
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import OneHotEncoder
    # sklearn.externals.joblib was removed in scikit-learn 0.23; import joblib directly.
    import joblib
    import numpy as np


    # Read the comma-separated file into a 2-D NumPy array of strings (one row per sample).
    iris = np.loadtxt('iris.data', dtype=str, delimiter=',', unpack=False, encoding='utf-8')

    # The first 4 columns are the features.
    # np.float was removed in NumPy 1.24; the builtin float is the same dtype (float64).
    data = iris[:, :4].astype(float)
    # The last column is the class label; np.newaxis reshapes it to (n_samples, 1),
    # the 2-D shape OneHotEncoder expects.
    target = iris[:, -1][:, np.newaxis]

    # One-hot encode the labels, then map each one-hot row back to its integer class
    # index (the position of the single 1 in the row).
    enc = OneHotEncoder()
    target = enc.fit_transform(target).toarray().argmax(axis=1)

    # Split into training and test data (80/20, fixed seed for reproducibility).
    X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=1)

    # Train the model.
    gbdt = GradientBoostingClassifier(n_estimators=3000, max_depth=2, min_samples_split=2, learning_rate=0.1)
    gbdt.fit(X_train, y_train)

    # Persist the model to disk...
    joblib.dump(gbdt, 'gbdt_model.pkl')
    # ...and load it back (only load pickles you created yourself: unpickling runs code).
    gbdt = joblib.load('gbdt_model.pkl')

    # Predict on the held-out test set.
    y_pred = gbdt.predict(X_test)

    # Evaluate the model.
    print('The accuracy of prediction is:', accuracy_score(y_test, y_pred))

    # Feature importances (one value per input column).
    print('Feature importances:', list(gbdt.feature_importances_))
    

    结果:

    The accuracy of prediction is: 0.9666666666666667
    Feature importances: [0.002148238569679191, 0.0046703830672789074, 0.33366676380518245, 0.6595146145578594]

    进行超参数搜索

    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.model_selection import train_test_split, GridSearchCV
    from sklearn.preprocessing import OneHotEncoder
    import numpy as np


    # Read the comma-separated file into a 2-D NumPy array of strings (one row per sample).
    iris = np.loadtxt('iris.data', dtype=str, delimiter=',', unpack=False, encoding='utf-8')

    # The first 4 columns are the features.
    # np.float was removed in NumPy 1.24; the builtin float is the same dtype (float64).
    data = iris[:, :4].astype(float)
    # The last column is the class label, reshaped to (n_samples, 1) for OneHotEncoder.
    target = iris[:, -1][:, np.newaxis]

    # One-hot encode the labels, then recover each row's integer class index
    # (the position of the single 1 in the row).
    enc = OneHotEncoder()
    target = enc.fit_transform(target).toarray().argmax(axis=1)

    # Split into training and test data (80/20, fixed seed for reproducibility).
    X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=1)


    # Exhaustive search over the grid below, using GridSearchCV's default
    # cross-validation on the training split only.
    # NOTE: 16 combinations x 2000-3000 trees x several folds -- this is slow.
    estimator = GradientBoostingClassifier()
    param_grid = {
        'learning_rate': [0.05, 0.1],
        'n_estimators': [2000, 3000],
        'max_depth': [2, 3],
        'min_samples_split': [2, 3],
    }
    gbm = GridSearchCV(estimator, param_grid)
    gbm.fit(X_train, y_train)
    print('Best parameters found by grid search are:', gbm.best_params_)
    

    结果:

    Best parameters found by grid search are: {'learning_rate': 0.05, 'max_depth': 2, 'min_samples_split': 2, 'n_estimators': 2000}

    基于sklearn接口的回归

    from sklearn.datasets import make_regression
    from sklearn.ensemble import GradientBoostingRegressor
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import mean_absolute_error
    # NOTE: the original also did `import lightgbm as lgb`, which was never used and
    # made the script fail when lightgbm is not installed; it has been dropped.

    # Synthetic single-feature regression data with Gaussian noise added to the target.
    X, y = make_regression(n_samples=100, n_features=1, noise=20)

    # Split into training and test sets.
    train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.25, random_state=1)

    # Build the GBDT regressor and fit it on the training data.
    # loss='ls' was renamed to 'squared_error' (deprecated in scikit-learn 1.0,
    # removed in 1.2); both mean least-squares loss.
    my_model = GradientBoostingRegressor(
        loss='squared_error',
        learning_rate=0.1,
        n_estimators=100,
        subsample=1,
        min_samples_split=2,
        min_samples_leaf=1,
        max_depth=3,
        init=None,
        random_state=None,
        max_features=None,
        alpha=0.9,  # only used by the 'huber' and 'quantile' losses
        verbose=0,
        max_leaf_nodes=None,
        warm_start=False,
    )

    my_model.fit(train_X, train_y)

    # Predict on the test set.
    predictions = my_model.predict(test_X)

    # Evaluate the predictions with mean absolute error.
    # Argument order is (y_true, y_pred) per the scikit-learn API.
    print("Mean Absolute Error : " + str(mean_absolute_error(test_y, predictions)))
    

    结果:

    Mean Absolute Error : 19.30372190886058

  • 相关阅读:
    TestLink学习六:TestLink1.9.13工作使用小结
    TestLink学习五:TestLink1.9.13和JIRA6.3.6的集成
    TestLink学习四:TestLink1.9.13使用说明
    TestLink学习三:发送邮件的两种配置方法
    TestLink学习二:Windows搭建TestLink环境
    TestLink学习一:Windows搭建Apache+MySQL+PHP环境
    Python:Ubuntu上使用pip安装opencv-python出现错误
    Python:Ubuntu上出现错误 Could not load dynamic library 'libnvinfer.so.6' / 'libnvinfer_plugin.so.6'
    mybatis-generator二次开发总结
    动态代理
  • 原文地址:https://www.cnblogs.com/xiximayou/p/14421561.html
Copyright © 2011-2022 走看看