  • XGBoost: Outputting Feature Importances and Selecting Features

    1. Outputting XGBoost feature importances

    from matplotlib import pyplot
    # model_XGB is an already-fitted XGBoost model; draw one bar per feature index
    pyplot.bar(range(len(model_XGB.feature_importances_)), model_XGB.feature_importances_)
    pyplot.show()

    You can also use XGBoost's built-in function for plotting feature importance:

    # plot feature importance using built-in function
    from xgboost import plot_importance
    plot_importance(model_XGB)
    pyplot.show()

    2. Selecting features by importance

    from numpy import sort
    from sklearn.feature_selection import SelectFromModel
    from sklearn.metrics import accuracy_score
    from xgboost import XGBClassifier

    # Use each feature importance, in ascending order, as a selection threshold
    thresholds = sort(model_XGB.feature_importances_)
    for thresh in thresholds:
      # select the features whose importance is >= thresh
      selection = SelectFromModel(model_XGB, threshold=thresh, prefit=True)
      select_X_train = selection.transform(X_train)
      # train a new model on the reduced feature set
      selection_model = XGBClassifier()
      selection_model.fit(select_X_train, y_train)
      # evaluate on the test set, reduced to the same features
      select_X_test = selection.transform(X_test)
      y_pred = selection_model.predict(select_X_test)
      predictions = [round(value) for value in y_pred]
      accuracy = accuracy_score(y_test, predictions)
      print("Thresh=%.3f, n=%d, Accuracy: %.2f%%" % (thresh, select_X_train.shape[1],
          accuracy*100.0))
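
To make the effect of each threshold concrete: with a numeric threshold, `SelectFromModel` is documented to keep the features whose importance is greater than or equal to that value. The NumPy-only sketch below mirrors that rule; it is not the library's implementation. The importance values, the `select_columns` helper, and `X_demo` are hypothetical and exist only for illustration.

```python
import numpy as np

# Hypothetical normalized importances for five features (illustrative values only).
importances = np.array([0.05, 0.40, 0.10, 0.30, 0.15])

def select_columns(X, importances, thresh):
    """Keep the columns of X whose importance is >= thresh."""
    mask = importances >= thresh
    return X[:, mask], mask

# Small stand-in matrix: 3 rows, 5 feature columns.
X_demo = np.arange(15, dtype=float).reshape(3, 5)

# Ascending thresholds, as in the loop above: each step drops the weakest remaining feature.
for thresh in np.sort(importances):
    X_sel, mask = select_columns(X_demo, importances, thresh)
    kept = np.flatnonzero(mask).tolist()
    print(f"thresh={thresh:.2f} keeps {X_sel.shape[1]} feature(s): {kept}")
```

Because the thresholds are the sorted importances themselves, the first iteration keeps every feature and the last keeps only the most important one, which is why the loop above prints a decreasing `n`.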

    Reference: https://blog.csdn.net/u011630575/article/details/79423162

  • Original article: https://www.cnblogs.com/tan2810/p/11154630.html