zoukankan      html  css  js  c++  java
  • XGBoost使用教程(与sklearn一起使用)二

    一、导入必要的工具包
    # 运行 xgboost安装包中的示例程序
    from xgboost import XGBClassifier

    # 加载LibSVM格式数据模块
    from sklearn.datasets import load_svmlight_file
    from sklearn.metrics import accuracy_score

    from matplotlib import pyplot
    二、数据读取
    scikit-learn支持多种格式的数据,包括LibSVM格式数据
    XGBoost可以加载libsvm格式的文本数据,libsvm的文件格式(稀疏特征)如下:
    1  101:1.2 102:0.03
    0  1:2.1 10001:300 10002:400
    ...
    每一行表示一个样本,第一行的开头的“1”是样本的标签。“101”和“102”为特征索引,'1.2'和'0.03' 为特征的值。
    在两类分类中,用“1”表示正样本,用“0” 表示负样本。也支持[0,1]表示概率用来做标签,表示为正样本的概率。
    下面的示例数据需要我们通过一些蘑菇的若干属性判断这个品种是否有毒。
    UCI数据描述:http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/ ,
    每个样本描述了蘑菇的22个属性,比如形状、气味等等(加工成libsvm格式后变成了126维特征),
    然后给出了这个蘑菇是否可食用。其中6513个样本做训练,1611个样本做测试。

    数据下载地址:http://download.csdn.net/download/u011630575/10266113

    # read in data,数据在xgboost安装的路径下的demo目录,现在copy到代码目录下的data目录
    my_workpath = './data/'
    X_train,y_train = load_svmlight_file(my_workpath + 'agaricus.txt.train')
    X_test,y_test = load_svmlight_file(my_workpath + 'agaricus.txt.test')

    print(X_train.shape)
    print (X_test.shape)
    三、训练参数设置

    max_depth: 树的最大深度。缺省值为6,取值范围为:[1,∞]
    eta:为了防止过拟合,更新过程中用到的收缩步长。在每次提升计算之后,算法会直接获得新特征的权重。
    eta通过缩减特征的权重使提升计算过程更加保守。缺省值为0.3,取值范围为:[0,1]
    silent:取0时表示打印出运行时信息,取1时表示以缄默方式运行,不打印运行时信息。缺省值为0
    objective: 定义学习任务及相应的学习目标,“binary:logistic” 表示二分类的逻辑回归问题,输出为概率。

    其他参数取默认值。
    四、训练模型

    # 设置boosting迭代计算次数
    num_round = 2


    bst =XGBClassifier(max_depth=2, learning_rate=1, n_estimators=num_round, 
                       silent=True, objective='binary:logistic') #sklearn api


    bst.fit(X_train, y_train)
    XGBoost预测的输出是概率。这里蘑菇分类是一个二类分类问题,输出值是样本为第一类的概率。
    我们需要将概率值转换为0或1。

    train_preds = bst.predict(X_train)
    train_predictions = [round(value) for value in train_preds]

    train_accuracy = accuracy_score(y_train, train_predictions)
    print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))
    五、测试

    模型训练好后,可以用训练好的模型对测试数据进行预测
    XGBoost预测的输出是概率,输出值是样本为第一类的概率。我们需要将概率值转换为0或1。

    # make prediction
    preds = bst.predict(X_test)
    predictions = [round(value) for value in preds]

    test_accuracy = accuracy_score(y_test, predictions)
    print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))
    六、代码整理

    # coding:utf-8
    # 运行 xgboost安装包中的示例程序
    from xgboost import XGBClassifier

    # 加载LibSVM格式数据模块
    from sklearn.datasets import load_svmlight_file
    from sklearn.metrics import accuracy_score

    from matplotlib import pyplot

    # read in data,数据在xgboost安装的路径下的demo目录,现在copy到代码目录下的data目录
    my_workpath = './data/'
    X_train,y_train = load_svmlight_file(my_workpath + 'agaricus.txt.train')
    X_test,y_test = load_svmlight_file(my_workpath + 'agaricus.txt.test')

    print(X_train.shape)
    print(X_test.shape)

    # 设置boosting迭代计算次数
    num_round = 2

    #bst = XGBClassifier(**params)
    #bst = XGBClassifier()
    bst =XGBClassifier(max_depth=2, learning_rate=1, n_estimators=num_round,
    silent=True, objective='binary:logistic')

    bst.fit(X_train, y_train)

    train_preds = bst.predict(X_train)
    train_predictions = [round(value) for value in train_preds]

    train_accuracy = accuracy_score(y_train, train_predictions)
    print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))

    # make prediction
    preds = bst.predict(X_test)
    predictions = [round(value) for value in preds]

    test_accuracy = accuracy_score(y_test, predictions)
    print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))
    ---------------------
    作者:鹤鹤有明
    来源:CSDN
    原文:https://blog.csdn.net/u011630575/article/details/79421053
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    Leetcode: K-th Smallest in Lexicographical Order
    Leetcode: Minimum Number of Arrows to Burst Balloons
    Leetcode: Minimum Moves to Equal Array Elements
    Leetcode: Number of Boomerangs
    Leetcode: Arranging Coins
    Leetcode: Path Sum III
    Leetcode: All O`one Data Structure
    Leetcode: Find Right Interval
    Leetcode: Non-overlapping Intervals
    Socket网络编程--简单Web服务器(3)
  • 原文地址:https://www.cnblogs.com/tan2810/p/11154725.html
Copyright © 2011-2022 走看看