  • LightGBM

    LightGBM:

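    LightGBM belongs to the boosting family of algorithms. Its full name is Light Gradient Boosting Machine, and it is a SOTA boosting framework open-sourced by Microsoft (first released in 2016; the accompanying paper was published at NIPS 2017).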

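    Like XGBoost, LightGBM is an engineering implementation of the GBDT (gradient boosting decision tree) framework. It is generally faster and more memory-efficient, mainly thanks to histogram-based split finding and leaf-wise tree growth.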
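    The sketch below makes "GBDT framework" concrete. It is a minimal gradient boosting loop for squared-error loss, written in plain NumPy with depth-1 trees (stumps). It illustrates the idea that XGBoost and LightGBM build on: each new tree is fitted to the negative gradient of the loss, which for squared error is simply the residual. It is not LightGBM's implementation. The helpers `fit_stump`, `gbdt_fit` and `gbdt_predict` are hypothetical names, and real libraries add histogram binning, deeper trees, regularization and much more.

```python
import numpy as np


def fit_stump(X, r):
    """Least-squares depth-1 tree fitted to residuals r.
    Returns (feature, threshold, left_value, right_value).
    Assumes at least one feature is non-constant."""
    best, best_err = None, np.inf
    for j in range(X.shape[1]):
        # Skip the largest unique value so both sides of the split are non-empty
        for t in np.unique(X[:, j])[:-1]:
            left = X[:, j] <= t
            lv, rv = r[left].mean(), r[~left].mean()
            err = ((r[left] - lv) ** 2).sum() + ((r[~left] - rv) ** 2).sum()
            if err < best_err:
                best_err, best = err, (j, t, lv, rv)
    return best


def predict_stump(stump, X):
    j, t, lv, rv = stump
    return np.where(X[:, j] <= t, lv, rv)


def gbdt_fit(X, y, n_trees=50, learning_rate=0.1):
    """Squared-error boosting: each stump fits y - F(x), the negative
    gradient of 0.5 * (y - F)^2 with respect to F."""
    base = y.mean()
    F = np.full(len(y), base)
    stumps = []
    for _ in range(n_trees):
        residual = y - F
        stump = fit_stump(X, residual)
        F += learning_rate * predict_stump(stump, X)  # shrinkage step
        stumps.append(stump)
    return base, learning_rate, stumps


def gbdt_predict(model, X):
    base, learning_rate, stumps = model
    F = np.full(X.shape[0], base)
    for stump in stumps:
        F += learning_rate * predict_stump(stump, X)
    return F


# Toy data: y depends non-linearly on the first feature only
rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 2))
y = np.sin(X[:, 0]) + 0.1 * rng.normal(size=200)
model = gbdt_fit(X, y, n_trees=100)
pred = gbdt_predict(model, X)
print('training MSE:', np.mean((y - pred) ** 2), 'baseline (variance):', np.var(y))
```

    Switching to another loss only changes what `residual` is: it becomes the negative gradient of that loss at the current predictions.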

    If installation fails, the following workaround may help:

    https://blog.csdn.net/sinat_36226553/article/details/106109821  (a neat workaround for problems when installing lightgbm)

    Example code using the iris dataset:

    import lightgbm as lgb
    from sklearn.metrics import mean_squared_error
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    
    # Load data
    iris = load_iris()
    data = iris.data
    target = iris.target  # class labels 0/1/2, treated as a numeric target for demonstration only
    X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=42)
    
    # Create the model
    gbm = lgb.LGBMRegressor(objective='regression',
                            num_leaves=31,
                            learning_rate=0.05,
                            n_estimators=20)
    # Train; LightGBM 4.x removed the early_stopping_rounds argument of fit(),
    # so early stopping is passed as a callback
    gbm.fit(X_train, y_train,
            eval_set=[(X_test, y_test)],
            eval_metric='l1',
            callbacks=[lgb.early_stopping(5)])
    # Predict the test set with the best iteration found by early stopping
    y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)
    # Evaluate: RMSE
    print(mean_squared_error(y_test, y_pred) ** 0.5)
    # Feature importance (by default, the number of splits that use each feature)
    print(list(gbm.feature_importances_))

    Training code for a LightGBM regression model with 5-fold cross-validation:

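    import time
    import numpy as np
    import pandas as pd
    import lightgbm as lgb
    from sklearn.model_selection import KFold
    from sklearn.metrics import mean_squared_error
    
    # `train` is assumed to be an already-loaded pandas DataFrame (loading is not shown).
    # Replace `label` with the actual target column name; 'total_cost' and the
    # categorical columns below come from the author's own dataset.
    label = 'total_cost'
    categorical_cols = ['hy', 'sex', 'pay_type']  # must be int or pandas 'category' dtype
    features = [f for f in train.columns if f != label]
    
    # Custom evaluation metric: half of the mean squared error
    def evalerror(preds, eval_data):
        y_true = eval_data.get_label()  # already a NumPy array, so no .values
        score = mean_squared_error(y_true, preds) * 0.5
        return ('half_mse', score, False)  # False: lower is better
    
    # Hyperparameters. sub_feature and colsample_bytree were aliases of
    # feature_fraction, so only the canonical name is kept; min_data and
    # min_hessian are written as min_data_in_leaf and min_sum_hessian_in_leaf.
    params = {
        'learning_rate': 0.01,
        'boosting_type': 'gbdt',
        'objective': 'regression',
        'metric': 'mse',
        'num_leaves': 60,
        'feature_fraction': 0.7,
        'min_data_in_leaf': 100,
        'min_sum_hessian_in_leaf': 1,
        'verbose': -1,
    }
    
    t0 = time.time()
    oof_preds = np.zeros(train.shape[0])  # out-of-fold predictions
    feat_importance = pd.Series(0.0, index=features)
    
    # 5-fold cross-validated training
    kf = KFold(n_splits=5, shuffle=True, random_state=43)
    for i, (train_index, valid_index) in enumerate(kf.split(train)):
        print('training fold {}...'.format(i))
        train2 = train.iloc[train_index]
        valid2 = train.iloc[valid_index]
        lgb_train = lgb.Dataset(train2[features], train2[label], categorical_feature=categorical_cols)
        lgb_valid = lgb.Dataset(valid2[features], valid2[label],
                                categorical_feature=categorical_cols, reference=lgb_train)
        # LightGBM 4.x removed verbose_eval / early_stopping_rounds from lgb.train;
        # logging and early stopping are configured through callbacks instead
        model = lgb.train(params,
                          lgb_train,
                          num_boost_round=3000,
                          valid_sets=[lgb_valid],
                          feval=evalerror,
                          callbacks=[lgb.early_stopping(100), lgb.log_evaluation(300)])
        # Sum feature importance (split counts) over the folds
        feat_importance += pd.Series(model.feature_importance(), index=features)
        # Each validation row is predicted by the model that did not train on it
        oof_preds[valid_index] = model.predict(valid2[features], num_iteration=model.best_iteration)
    
    # Features sorted by importance
    print(feat_importance.sort_values(ascending=False))
    print('Out-of-fold score: {}'.format(mean_squared_error(train[label], oof_preds) * 0.5))
    print('Cross-validation took {:.1f} seconds'.format(time.time() - t0))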

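    Additional parameters:

    max_depth, default=-1, type=int, maximum tree depth (<= 0 means no limit); limiting it helps prevent overfitting
    min_data_in_leaf, default=20, type=int, minimum number of samples in a leaf; helps prevent overfitting
    feature_fraction, default=1.0, type=double, 0.0 < feature_fraction <= 1.0, fraction of features randomly selected for each tree; speeds up training and helps prevent overfitting
    feature_fraction_seed, default=2, type=int, random seed for feature_fraction, so the randomly selected features are reproducible
    bagging_fraction, default=1.0, type=double, 0.0 < bagging_fraction <= 1.0, similar to random forests: randomly selects this fraction of the data without replacement (only takes effect when bagging_freq > 0)
    lambda_l1, default=0, type=double, L1 regularization
    lambda_l2, default=0, type=double, L2 regularization
    min_split_gain, default=0, type=double, minimum gain required to make a split
    top_rate, default=0.2, type=double, retain ratio of large-gradient data (GOSS only)
    other_rate, default=0.1, type=double, retain ratio of small-gradient data (GOSS only)
    min_data_per_group, default=100, type=int, minimum number of samples per categorical group
    max_cat_threshold, default=32, type=int, maximum number of split points considered for a categorical feature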
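    `top_rate` and `other_rate` belong to GOSS (gradient-based one-side sampling). LightGBM applies them only when GOSS is enabled: `data_sample_strategy='goss'` in LightGBM 4.x, or `boosting='goss'` in older versions. The LightGBM paper describes the idea in three steps:

    - Keep the `top_rate` fraction of samples with the largest absolute gradients.
    - Randomly sample an `other_rate` fraction (of the full data) from the remaining rows.
    - Up-weight the sampled small-gradient rows by `(1 - top_rate) / other_rate`, so that gradient statistics stay roughly unbiased.

    The NumPy sketch below implements that sampling step. `goss_sample` is a hypothetical helper, not LightGBM's internal code.

```python
import numpy as np


def goss_sample(gradients, top_rate=0.2, other_rate=0.1, rng=None):
    """GOSS sampling as described in the LightGBM paper.
    Returns (indices, weights): selected row indices and the weight each
    selected row gets when accumulating gradient statistics."""
    rng = np.random.default_rng() if rng is None else rng
    n = len(gradients)
    top_n = int(top_rate * n)
    rand_n = int(other_rate * n)
    # Sort rows by absolute gradient, largest first
    order = np.argsort(-np.abs(gradients))
    top_idx = order[:top_n]
    # Sample rand_n rows without replacement from the small-gradient remainder
    rand_idx = rng.choice(order[top_n:], size=rand_n, replace=False)
    indices = np.concatenate([top_idx, rand_idx])
    weights = np.ones(len(indices))
    # Amplify small-gradient rows to compensate for under-sampling them
    weights[top_n:] = (1 - top_rate) / other_rate
    return indices, weights


rng = np.random.default_rng(0)
g = rng.normal(size=1000)
idx, w = goss_sample(g, top_rate=0.2, other_rate=0.1, rng=rng)
print(len(idx))  # 300: 200 large-gradient rows + 100 sampled rows
```

    With the defaults, each tree is built from 30% of the rows, yet every large-gradient (poorly fitted) row is still included.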

  • Original post: https://www.cnblogs.com/cgmcoding/p/13267014.html