zoukankan      html  css  js  c++  java
  • 转 :scikit-learn的GBDT工具进行特征选取。

    http://blog.csdn.net/w5310335/article/details/48972587

    使用GBDT选取特征

    2015-03-31

    本文介绍如何使用scikit-learn的GBDT工具进行特征选取。

    为什麽选取特征


    有些特征意义不大,删除后不影响效果,甚至可能提升效果。

    关于GBDT(Gradient Boosting Decision Tree)


    可以参考:

    GBDT(MART)概念简介

    GBDT(MART) 迭代决策树入门教程 | 简介

    机器学习中的算法(1)-决策树模型组合之随机森林与GBDT

    如何在numpy数组中选取若干列或者行?


    >>> import numpy as np
    >>> tmp_a = np.array([[1,1], [0.4, 4], [1., 0.9]])
    >>> tmp_a
    array([[ 1. ,  1. ],  
           [ 0.4,  4. ],
           [ 1. ,  0.9]])
    >>> tmp_a[[0,1],:]  # 选第0、1行
    array([[ 1. ,  1. ],  
           [ 0.4,  4. ]])
    >>> tmp_a[np.array([True, False, True]), :]  # 选第0、2行
    array([[ 1. ,  1. ],  
           [ 1. ,  0.9]])
    >>> tmp_a[:,[0]]    # 选第0列
    array([[ 1. ],  
           [ 0.4],
           [ 1. ]])
    >>> tmp_a[:, np.array([True, False])]  # 选第0列
    array([[ 1. ],  
           [ 0.4],
           [ 1. ]])
    

    生成数据集


    参考基于贝叶斯的文本分类实战。部分方法在原始数据集的预测效果也在基于贝叶斯的文本分类实战这篇文章里。

    训练GBDT


    >>> from sklearn.ensemble import GradientBoostingClassifier
    >>> gbdt = GradientBoostingClassifier()
    >>> gbdt.fit(training_data, training_labels)  # 训练。喝杯咖啡吧
    GradientBoostingClassifier(init=None, learning_rate=0.1, loss='deviance',  
                  max_depth=3, max_features=None, max_leaf_nodes=None,
                  min_samples_leaf=1, min_samples_split=2,
                  min_weight_fraction_leaf=0.0, n_estimators=100,
                  random_state=None, subsample=1.0, verbose=0,
                  warm_start=False)
    >>> gbdt.feature_importances_   # 据此选取重要的特征
    array([  2.08644807e-06,   0.00000000e+00,   8.93452010e-04, ...,  
             5.12199658e-04,   0.00000000e+00,   0.00000000e+00])
    >>> gbdt.feature_importances_.shape
    (19630,)
    

    看一下GBDT的分类效果:

    >>> gbdt_predict_labels = gbdt.predict(test_data)
    >>> sum(gbdt_predict_labels==test_labels)  # 比 多项式贝叶斯 差许多
    414  
    

    新的训练集和测试集(只保留了1636个特征,原先是19630个特征):

    >>> new_train_data = training_data[:, feature_importances>0]
    >>> new_train_data.shape  # 只保留了1636个特征
    (1998, 1636)
    >>> new_test_data = test_data[:, feature_importances>0]
    >>> new_test_data.shape
    (509, 1636)
    

    使用多项式贝叶斯处理新数据


    >>> from sklearn.naive_bayes import MultinomialNB
    >>> bayes = MultinomialNB() 
    >>> bayes.fit(new_train_data, training_labels)
    MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)  
    >>> bayes_predict_labels = bayes.predict(new_test_data)
    >>> sum(bayes_predict_labels == test_labels)   # 之前预测正确的样本数量是454
    445  
    

    使用伯努利贝叶斯处理新数据


    >>> from sklearn.naive_bayes import BernoulliNB
    >>> bayes2 = BernoulliNB()
    >>> bayes2.fit(new_train_data, training_labels)
    BernoulliNB(alpha=1.0, binarize=0.0, class_prior=None, fit_prior=True)  
    >>> bayes_predict_labels = bayes2.predict(new_test_data)
    >>> sum(bayes_predict_labels == test_labels)   # 之前预测正确的样本数量是387
    422  
    

    使用Logistic回归处理新数据


    对原始特征组成的数据集:

    >>> from sklearn.linear_model import LogisticRegression
    >>> lr1 = LogisticRegression()
    >>> lr1.fit(training_data, training_labels)
    LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,  
              intercept_scaling=1, max_iter=100, multi_class='ovr',
              penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
              verbose=0)
    >>> lr1_predict_labels = lr1.predict(test_data)
    >>> sum(lr1_predict_labels == test_labels)
    446  
    

    对削减后的特征组成的数据集:

    >>> lr2 = LogisticRegression()
    >>> lr2.fit(new_train_data, training_labels)
    LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,  
              intercept_scaling=1, max_iter=100, multi_class='ovr',
              penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
              verbose=0)
    >>> lr2_predict_labels = lr2.predict(new_test_data)
    >>> sum(lr2_predict_labels == test_labels)  # 正确率略微提升
    449  
    

    (完)

  • 相关阅读:
    百度开发者中心BAE新建Java应用
    微信公众平台开发(三)位置信息的识别
    确定路名、标志性建筑和商场名的经度纬度
    Eclipse中Java Project转换为Java Web Project
    你应该知道的8个Java牛人
    周边信息查询
    在Java中避免空指针异常(Null Pointer Exception)
    google guava使用例子/示范(一)
    docker 相关链接
    HashMap 的数据结构
  • 原文地址:https://www.cnblogs.com/DjangoBlog/p/6211269.html
Copyright © 2011-2022 走看看