zoukankan      html  css  js  c++  java
  • Xgboost 两种使用方式

    原生形式使用Xgboost(import xgboost as xgb)

    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    import xgboost as xgb
    import numpy as np
    from sklearn.metrics import precision_score, recall_score
    
    # 加载数据
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    
    # 划分训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42)
    print("Train data length:", len(X_train))
    print("Test data length:", len(X_test))
    
    # 转换为DMatrix数据格式
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dtest = xgb.DMatrix(X_test, label=y_test)
    
    # 设置参数
    parameters = {
        'eta': 0.3,
        'silent': True,  # option for logging
        'objective': 'multi:softprob',  # error evaluation for multiclass tasks
        'num_class': 3,  # number of classes to predic
        'max_depth': 3  # depth of the trees in the boosting process
    }
    num_round = 20  # the number of training iterations
    
    # 模型训练
    bst = xgb.train(parameters, dtrain, num_round)
    
    # 模型预测
    preds = bst.predict(dtest)
    
    print(preds[:5])
    
    # 选择表示最高概率的列
    best_preds = np.asarray([np.argmax(line) for line in preds])
    print(best_preds)
    
    # 模型评估
    print(precision_score(y_test, best_preds, average='macro'))  # 精准率
    print(recall_score(y_test, best_preds, average='macro'))  # 召回率
    
    
    

    Sklearn接口形式使用Xgboost(from xgboost import XGBClassifier)

    from sklearn import datasets
    from sklearn.model_selection import train_test_split
    from xgboost import XGBClassifier
    from sklearn.metrics import precision_score, recall_score
    
    # 加载数据
    iris = datasets.load_iris()
    X = iris.data
    y = iris.target
    
    # 划分训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=42)
    print("Train data length:", len(X_train))
    print("Test data length:", len(X_test))
    
    # 模型训练
    model = XGBClassifier(
            learning_rate=0.01,
            n_estimators=3000,
            max_depth=4,
            min_child_weight=5,
            gamma=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            reg_alpha=1,
            objective='binary:logistic',
            nthread=8,
            scale_pos_weight=1,
            seed=27
        )
    model.fit(X_train, y_train)
    
    # 预测
    y_pred = model.predict(X_test)
    
    # 模型评估
    print(precision_score(y_test, y_pred, average='macro'))  # 精准率
    print(recall_score(y_test, y_pred, average='macro'))  # 召回率
    
  • 相关阅读:
    手机号码 正则表达式
    邮政编码的正则表达式
    对象为null,调用非静态方法产生空指针异常
    文件找不到异常(FileNotFoundException)
    数组下标越界异常解决方法
    空指针异常的解决方法
    需求:打印九九乘法表
    创建简单线程
    ·博客作业06--图
    博客作业05--查找
  • 原文地址:https://www.cnblogs.com/chenxiangzhen/p/10894143.html
Copyright © 2011-2022 走看看