zoukankan      html  css  js  c++  java
  • 【机器学习】集成学习之sklearn中的xgboost基本用法

    原创博文,转载请注明出处!本文代码的github地址    博客索引地址

    1.数据集

          数据集使用sklearn自带的手写数字识别数据集mnist,通过函数datasets导入。mnist共1797个样本,8*8个特征,标签为0~9十个数字。

      1 ### 载入数据
      2 from sklearn import datasets    # 载入数据集
      3 digits = datasets.load_digits() # 载入mnist数据集
      4 print(digits.data.shape)        # 打印输入空间维度
      5 print(digits.target.shape)      # 打印输出空间维度
      6 """
      7 (1797, 64)
      8 (1797,)
      9 """

    2.数据集分割

          sklearn.model_selection中train_test_split函数划分数据集,其中参数test_size为测试集所占的比例,random_state为随机种子(为了能够复现实验结果而设定)。

      1 ### 数据分割
      2 from sklearn.model_selection import train_test_split                 # 载入数据分割函数train_test_split
      3 x_train,x_test,y_train,y_test = train_test_split(digits.data,        # 特征空间
      4                                                  digits.target,      # 输出空间
      5                                                  test_size = 0.3,    # 测试集占30%
      6                                                  random_state = 33)  # 为了复现实验,设置一个随机数
      7 

    3.模型相关(载入模型--训练模型--模型预测)

          XGBClassifier.fit()函数用于训练模型,XGBClassifier.predict()函数为使用模型做预测。

      1 ### 模型相关
      2 from xgboost import XGBClassifier
      3 model = XGBClassifier()               # 载入模型(模型命名为model)
      4 model.fit(x_train,y_train)            # 训练模型(训练集)
      5 y_pred = model.predict(x_test)        # 模型预测(测试集),y_pred为预测结果

    4.性能评估

          sklearn.metrics中accuracy_score函数用来判断模型预测的准确度。

      1 ### 性能度量
      2 from sklearn.metrics import accuracy_score   # 准确率
      3 accuracy = accuracy_score(y_test,y_pred)
      4 print("accuarcy: %.2f%%" % (accuracy*100.0))
      5 
      6 """
      7 95.0%
      8 """

    5.特征重要性

          xgboost分析了特征的重要程度,通过函数plot_importance绘制图片。

      1 ### 特征重要性
      2 import matplotlib.pyplot as plt
      3 from xgboost import plot_importance
      4 fig,ax = plt.subplots(figsize=(10,15))
      5 plot_importance(model,height=0.5,max_num_features=64,ax=ax)
      6 plt.show()

    image

    6.完整代码

      1 # -*- coding: utf-8 -*-
      2 """
      3 ###############################################################################
      4 # 作者:wanglei5205
      5 # 邮箱:wanglei5205@126.com
      6 # 代码:http://github.com/wanglei5205
      7 # 博客:http://cnblogs.com/wanglei5205
      8 # 目的:xgboost基本用法
      9 ###############################################################################
     10 """
     11 ### load module
     12 from sklearn import datasets
     13 from sklearn.model_selection import train_test_split
     14 from xgboost import XGBClassifier
     15 from sklearn.metrics import accuracy_score
     16 
     17 ### load datasets
     18 digits = datasets.load_digits()
     19 
     20 ### data analysis
     21 print(digits.data.shape)   # 输入空间维度
     22 print(digits.target.shape) # 输出空间维度
     23 
     24 ### data split
     25 x_train,x_test,y_train,y_test = train_test_split(digits.data,
     26                                                  digits.target,
     27                                                  test_size = 0.3,
     28                                                  random_state = 33)
     29 
     30 ### fit model for train data
     31 model = XGBClassifier()
     32 model.fit(x_train,y_train)
     33 
     34 ### make prediction for test data
     35 y_pred = model.predict(x_test)
     36 
     37 ### model evaluate
     38 accuracy = accuracy_score(y_test,y_pred)
     39 print("accuarcy: %.2f%%" % (accuracy*100.0))
     40 """
     41 95.0%
     42 """
  • 相关阅读:
    vue移动端滚动插件BetterScroll
    vue商品推荐信息展示 案例
    css吸顶效果
    vue TabControl案例
    首页导航栏样式 案例
    HO引擎近况20210713
    go定时器--timer
    go定时器--Ticker
    Go测试--main测试
    Spring 核心技术 AOP 实例
  • 原文地址:https://www.cnblogs.com/wanglei5205/p/8578486.html
Copyright © 2011-2022 走看看