这一节主要介绍一下XGBoost算法在CPU/GPU版本下的代码编写基本流程,主要分为以下几个部分:
- 构造训练集/验证集
- 算法参数设置
- XGBoost模型训练/验证
- 模型预测
主要面对的任务场景是多分类任务,下一节再说回归任务;
另外,除上述几个部分外,会涉及到sklearn用于加载数据集以及最后的模型预测的评价指标计算;
导入使用到的库:
import time
import xgboost as xgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, classification
from libs.xgboost_plot import plot_training_merror
1. 构造数据集/验证集
使用sklearn
导入数据集,并进一步拆分成训练集、验证集;
# 使用sklearn加载数据集,并进一步拆分
digits = datasets.load_digits()
data, labels = digits.data, digits.target
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=7)
print("x_train: {}, x_test: {}, classes: {}".format(x_train.shape, x_test.shape, len(set(labels))))
构造XGBoost算法需要的输入格式:
dtrain = xgb.DMatrix(x_train, y_train) # 训练集
dtest = xgb.DMatrix(x_test, y_test) # 验证集
evals = [(dtrain, 'train'), (dtest, 'val')] # 训练过程中进行验证
2. 算法参数设置
算法模型参数设置,详情见:XGBoost Parameters
params = {
'tree_method': "gpu_hist",
'booster': 'gbtree',
'objective': 'multi:softmax',
'num_class': 10,
'max_depth': 6,
'eval_metric': 'merror',
'eta': 0.01,
'verbosity': 0,
'gpu_id': 0
}
简单介绍以下:
tree_method
,gpu_hist
表示使用GPU运算,影响的可以使用hist
利用CPU计算;objective
,目标函数num_class
,类别数量,配合multi:softmax
使用;
3. XGBoost模型训练/验证
模型训练、保存模型、绘制merror
图像;
s_time = time.time()
train_res = {}
model = xgb.train(params, dtrain, num_boost_round=100,
evals=evals,
evals_result=train_res)
print("模型训练耗时: {}".format(time.time() - s_time))
# 模型保存
save_path = "./saved_model/model.model"
model.save_model(save_path)
# train/val的merror绘图
merror_img_path = "./test/error.png"
plot_training_merror(train_res, merror_img_path)
4. 模型预测
模型预测,打印预测结果
pred_data = model.predict(dtest)
res = classification_report(y_test, pred_data)
print(res)
输出如下:
precision recall f1-score support
0 1.00 0.95 0.98 43
1 0.86 1.00 0.92 42
2 0.98 1.00 0.99 40
3 0.89 0.97 0.93 34
4 0.92 0.89 0.90 37
5 0.93 0.96 0.95 28
6 0.96 0.93 0.95 28
7 0.86 0.94 0.90 33
8 0.95 0.81 0.88 43
9 0.96 0.81 0.88 32
accuracy 0.93 360
macro avg 0.93 0.93 0.93 360
weighted avg 0.93 0.93 0.93 360
5. 结语
XGBoost框架最基本的使用,也就是这个流程了:
- 构建数据集
- 参数设置
- 模型训练、保存、预测;
当然,在实际应用中,每一个步骤中,都存在很多的细节值得深究,也必须深究;
不然很难做到知其然知其所以然,对于实际的问题,也很难获得一个很好的结果;