zoukankan      html  css  js  c++  java
  • Scikit-Learn与决策树

    Scikit-Learn(决策树)可以用于方法分类回归。

    一、分类

    sklearn.tree.DecisionTreeClassifier(criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_split=1e-07, class_weight=None, presort=False)参数探讨

    • criterion :('Gini'、‘entropy’)表示在基于特征划分数据集合时,选择特征的标准。默认是’gini‘,即'Gini impurity'(Gini不纯度),还可以是criterion='entropy'。Gini不纯度表示该Gini度量是指随机选择集合中的元素,根据集合中label的分布将该元素赋予分类,该元素分类错误的几率;entropy则表示采用信息增益来选择特征。别人看法:criterion=entropy应该理解为决策树采用的是ID3算法,而不是cart树。
    • splitter :('best' , 'random')表示在构造树时,选择结点的原则,默认是splitter='best',即选择最好的特征点分类,比如基于信息增益分类时,则选择信息增益最大的特征点,还可以是'random'
    • max_depth  :int,default=None,表示树的最大深度。默认为"None",表示树的最大深度。如果是"None",则节点会一直扩展直到所有的叶子都是纯的或者所有的叶子节点都包含少于min_samples_split个样本点。忽视max_leaf_nodes是不是为None。
    • min_samples_split  :int,float,optional(default=2),区分一个内部节点需要的最少的样本数。1.如果是int,将其最为最小的样本数。2.如果是float,min_samples_split是一个百分率并且ceil(min_samples_split*n_samples)是每个分类需要的样本数。ceil是取大于或等于指定表达式的最小整数。
    •  min_samples_leaf  :int,float,optional(default=1),一个叶节点所需要的最小样本数。 1.如果是int,则其为最小样本数。 2.如果是float,则它是一个百分率并且ceil(min_samples_leaf*n_samples)是每个节点所需的样本数。
    •  min_weight_fraction_leaf :float,optional(default=0),如果设置为0,则表示所有样本的权重是一样的
    • max_features :这个参数表示在划分数据集时考虑的最多的特征值数量,根据数据类型表示的意义也不同。int值,在每次split时,最大特征数;float,表示百分数,即(max_features * n_features);'auto'->max_features=sqrt(n_features);'sqrt'->max_features=sqrt(n_features);
    • max_leaf_nodes  :int,None,optional(default=None),主要是在最优分类中考虑
    • class_weight  :dict,list of dicts,"Banlanced" or None,可选(默认为None)如果没有指定,所有类的权值都为1。对于多输出问题,一列字典的顺序可以与一列y的次序相同。 "balanced"模型使用y的值去自动适应权值,并且是以输入数据中类的频率的反比例。如果sample_weight已经指定了,这些权值将于samples以合适的方法相乘。
    • persort  :bool,可选(默认为False)是否预分类数据以加速训练时最好分类的查找。在有大数据集的决策树中,如果设为true可能会减慢训练的过程。当使用一个小数据集或者一个深度受限的决策树中,可以减速训练的过程。
    • min_impurity_split :float, optional (default=1e-7),树增长停止阈值,仅仅当他的impurity超过阈值时才会继续向下分解,否则会成为叶结点

     例子:

    from sklearn import tree
    
    X = [[1, 1],[1, 1], [1, 0],[0, 1], [0, 1]]
    Y = [1, 1, 0, 0, 0]
    
    clf = tree.DecisionTreeClassifier(criterion='entropy')
    clf = clf.fit(X, Y)
    #predict_proba(Xcheck_input=True) 预测x中的分类概率 result
    = clf.predict([0,0]) print result

    训练后,我们可以使用导出器以Graphviz(需要单独安装)格式导出树export_graphviz 。

    with open("iris.dot", 'w') as f:
        f = tree.export_graphviz(clf, out_file=f)

    然后我们可以使用的Graphviz的dot工具来创建一个PDF文件(或任何其他支持的文件类型)

    dot -Tpdf iris.dot -o iris.pdf

    二、回归

    sklearn.tree.DecisionTreeRegressor(criterion='mse', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_split=1e-07, presort=False)

    • criteria:string,可选(default =“mse”)。测量分割质量的功能。对于均方误差,支持的标准是“mse”,其等于作为特征选择标准的方差减小,以及平均绝对误差的“mae”。

    其他参与与DecisionTreeClassifier类似

    from sklearn import tree
    
    X = [[1],[2],[3],[4],[5],[6]]
    Y = [1,2,3,4,5,6]
    
    clf = tree.DecisionTreeRegressor(criterion='mae')
    clf = clf.fit(X, Y)
    result = clf.predict([4])
    print result

    输出:

    [ 4.]
  • 相关阅读:
    labview 中的一些简写全称
    socket
    putty
    在波形图表中显示多条曲线
    简单的通电延时触发电路
    Linux sed 批量替换多个文件中的字符串
    PhpMyAdmin管理,登录多台远程MySQL服务器
    MySQL客户端工具推荐
    Redis的几个认识误区
    Redis 的 5 个常见使用场景
  • 原文地址:https://www.cnblogs.com/lovephysics/p/7235089.html
Copyright © 2011-2022 走看看