zoukankan      html  css  js  c++  java
  • 转 :scikit-learn的GBDT工具进行特征选取。

    http://blog.csdn.net/w5310335/article/details/48972587

    使用GBDT选取特征

    2015-03-31

    本文介绍如何使用scikit-learn的GBDT工具进行特征选取。

    为什麽选取特征


    有些特征意义不大,删除后不影响效果,甚至可能提升效果。

    关于GBDT(Gradient Boosting Decision Tree)


    可以参考:

    GBDT(MART)概念简介

    GBDT(MART) 迭代决策树入门教程 | 简介

    机器学习中的算法(1)-决策树模型组合之随机森林与GBDT

    如何在numpy数组中选取若干列或者行?


    >>> import numpy as np
    >>> tmp_a = np.array([[1,1], [0.4, 4], [1., 0.9]])
    >>> tmp_a
    array([[ 1. ,  1. ],  
           [ 0.4,  4. ],
           [ 1. ,  0.9]])
    >>> tmp_a[[0,1],:]  # 选第0、1行
    array([[ 1. ,  1. ],  
           [ 0.4,  4. ]])
    >>> tmp_a[np.array([True, False, True]), :]  # 选第0、2行
    array([[ 1. ,  1. ],  
           [ 1. ,  0.9]])
    >>> tmp_a[:,[0]]    # 选第0列
    array([[ 1. ],  
           [ 0.4],
           [ 1. ]])
    >>> tmp_a[:, np.array([True, False])]  # 选第0列
    array([[ 1. ],  
           [ 0.4],
           [ 1. ]])
    

    生成数据集


    参考基于贝叶斯的文本分类实战。部分方法在原始数据集的预测效果也在基于贝叶斯的文本分类实战这篇文章里。

    训练GBDT


    >>> from sklearn.ensemble import GradientBoostingClassifier
    >>> gbdt = GradientBoostingClassifier()
    >>> gbdt.fit(training_data, training_labels)  # 训练。喝杯咖啡吧
    GradientBoostingClassifier(init=None, learning_rate=0.1, loss='deviance',  
                  max_depth=3, max_features=None, max_leaf_nodes=None,
                  min_samples_leaf=1, min_samples_split=2,
                  min_weight_fraction_leaf=0.0, n_estimators=100,
                  random_state=None, subsample=1.0, verbose=0,
                  warm_start=False)
    >>> gbdt.feature_importances_   # 据此选取重要的特征
    array([  2.08644807e-06,   0.00000000e+00,   8.93452010e-04, ...,  
             5.12199658e-04,   0.00000000e+00,   0.00000000e+00])
    >>> gbdt.feature_importances_.shape
    (19630,)
    

    看一下GBDT的分类效果:

    >>> gbdt_predict_labels = gbdt.predict(test_data)
    >>> sum(gbdt_predict_labels==test_labels)  # 比 多项式贝叶斯 差许多
    414  
    

    新的训练集和测试集(只保留了1636个特征,原先是19630个特征):

    >>> new_train_data = training_data[:, feature_importances>0]
    >>> new_train_data.shape  # 只保留了1636个特征
    (1998, 1636)
    >>> new_test_data = test_data[:, feature_importances>0]
    >>> new_test_data.shape
    (509, 1636)
    

    使用多项式贝叶斯处理新数据


    >>> from sklearn.naive_bayes import MultinomialNB
    >>> bayes = MultinomialNB() 
    >>> bayes.fit(new_train_data, training_labels)
    MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)  
    >>> bayes_predict_labels = bayes.predict(new_test_data)
    >>> sum(bayes_predict_labels == test_labels)   # 之前预测正确的样本数量是454
    445  
    

    使用伯努利贝叶斯处理新数据


    >>> from sklearn.naive_bayes import BernoulliNB
    >>> bayes2 = BernoulliNB()
    >>> bayes2.fit(new_train_data, training_labels)
    BernoulliNB(alpha=1.0, binarize=0.0, class_prior=None, fit_prior=True)  
    >>> bayes_predict_labels = bayes2.predict(new_test_data)
    >>> sum(bayes_predict_labels == test_labels)   # 之前预测正确的样本数量是387
    422  
    

    使用Logistic回归处理新数据


    对原始特征组成的数据集:

    >>> from sklearn.linear_model import LogisticRegression
    >>> lr1 = LogisticRegression()
    >>> lr1.fit(training_data, training_labels)
    LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,  
              intercept_scaling=1, max_iter=100, multi_class='ovr',
              penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
              verbose=0)
    >>> lr1_predict_labels = lr1.predict(test_data)
    >>> sum(lr1_predict_labels == test_labels)
    446  
    

    对削减后的特征组成的数据集:

    >>> lr2 = LogisticRegression()
    >>> lr2.fit(new_train_data, training_labels)
    LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,  
              intercept_scaling=1, max_iter=100, multi_class='ovr',
              penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
              verbose=0)
    >>> lr2_predict_labels = lr2.predict(new_test_data)
    >>> sum(lr2_predict_labels == test_labels)  # 正确率略微提升
    449  
    

    (完)

  • 相关阅读:
    shell脚本根据端口号kill掉进程
    使用netstat -ano 查看机器端口的占用情况(windows环境)
    分享一两个小工具,
    将压缩文件伪装图片格式文件以及将python文件转化为exe文件(测试完,真的有效)
    celery 异步任务 周期任务 定时任务的实现
    wsgi、uwsgi、asgi协议的关系
    centos7忘记密码更改步骤
    工作遇到的坑以及自己的学习悟道之道
    案例小集锦
    asp.net mvc部署
  • 原文地址:https://www.cnblogs.com/DjangoBlog/p/6211269.html
Copyright © 2011-2022 走看看