  • Python Spark decision tree: introductory demo

    Refer to the DecisionTree Python docs and DecisionTreeModel Python docs for more details on the API.

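    from pyspark import SparkContext
    from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
    from pyspark.mllib.util import MLUtils
    
    # Create the SparkContext; omit this line in the pyspark shell, which already defines `sc`.
    sc = SparkContext(appName="PythonDecisionTreeClassificationExample")
    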
    # Load and parse the data file into an RDD of LabeledPoint.
    data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt')
    # Split the data into training and test sets (30% held out for testing)
    (trainingData, testData) = data.randomSplit([0.7, 0.3])
    
    # Train a DecisionTree model.
    #  Empty categoricalFeaturesInfo indicates all features are continuous.
    model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                         impurity='gini', maxDepth=5, maxBins=32)
    
    # Evaluate model on test instances and compute test error
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
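    # Python 3 removed tuple-parameter unpacking in lambdas, so index into each (label, prediction) pair.
    testErr = labelsAndPredictions.filter(lambda lp: lp[0] != lp[1]).count() / float(testData.count())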
    print('Test Error = ' + str(testErr))
    print('Learned classification tree model:')
    print(model.toDebugString())
    
    # Save and load model
    model.save(sc, "target/tmp/myDecisionTreeClassificationModel")
    sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel")
    
    Find full example code at "examples/src/main/python/mllib/decision_tree_classification_example.py" in the Spark repo.
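    The test error is the number of mismatched (label, prediction) pairs divided by the size of the test set. A plain-Python sketch of that arithmetic follows, using made-up pairs instead of Spark output; the helper name `error_rate` is hypothetical.

```python
# Hypothetical (label, prediction) pairs standing in for what
# labelsAndPredictions.collect() might return; the values are made up.
pairs = [(0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0), (1.0, 1.0)]


def error_rate(label_prediction_pairs):
    """Fraction of (label, prediction) pairs where the prediction is wrong."""
    items = list(label_prediction_pairs)
    if not items:
        raise ValueError("need at least one (label, prediction) pair")
    # Python 3 has no `lambda (v, p): ...` tuple parameters,
    # so index into each pair instead.
    mismatches = sum(1 for lp in items if lp[0] != lp[1])
    return mismatches / len(items)  # true division in Python 3


print('Test Error = ' + str(error_rate(pairs)))  # prints: Test Error = 0.2
```

    One of the five pairs disagrees, so the error is 1/5. In the Spark version the same count comes from `filter(...).count()` on the RDD, and `float(...)` keeps the division fractional under Python 2 as well.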

    class pyspark.mllib.tree.DecisionTree

    Learning algorithm for a decision tree model for classification or regression.

    New in version 1.1.0.

    classmethod trainClassifier(data, numClasses, categoricalFeaturesInfo, impurity='gini', maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0)

    Train a decision tree model for classification.

    Parameters:
    • data – Training data: RDD of LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
    • numClasses – Number of classes for classification.
    • categoricalFeaturesInfo – Map storing arity of categorical features. An entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
    • impurity – Criterion used for information gain calculation. Supported values: “gini” or “entropy”. (default: “gini”)
    • maxDepth – Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1 means 1 internal node + 2 leaf nodes). (default: 5)
    • maxBins – Number of bins used for finding splits at each node. (default: 32)
    • minInstancesPerNode – Minimum number of instances required at child nodes to create the parent split. (default: 1)
    • minInfoGain – Minimum info gain required to create a split. (default: 0.0)
    Returns:

    DecisionTreeModel.
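    The `impurity` option chooses how node purity is scored, and `minInfoGain` is a lower bound on the information gain a split must reach. A plain-Python sketch of both supported measures and of information gain follows. The helper names are hypothetical, and measuring entropy in bits (log base 2) is an assumption of this sketch rather than something the parameter list states; it illustrates the formulas, not Spark's internal code.

```python
from collections import Counter
from math import log2


def class_probabilities(labels):
    """Relative frequency of each distinct label."""
    counts = Counter(labels)
    n = len(labels)
    return [c / n for c in counts.values()]


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2 (0 for a pure node)."""
    return 1.0 - sum(p * p for p in class_probabilities(labels))


def entropy(labels):
    """Entropy in bits (assumed base 2): -sum_k p_k * log2(p_k)."""
    return sum(-p * log2(p) for p in class_probabilities(labels))


def information_gain(parent, left, right, impurity=gini):
    """Parent impurity minus the size-weighted impurity of the two children."""
    n = len(parent)
    children = len(left) / n * impurity(left) + len(right) / n * impurity(right)
    return impurity(parent) - children


parent = [0, 1, 1, 1]
print(gini(parent))     # 0.375, i.e. 1 - (0.25**2 + 0.75**2)
print(entropy(parent))  # roughly 0.811 bits
print(information_gain(parent, [0], [1, 1, 1]))  # 0.375: both children are pure
```

    A split whose gain falls below `minInfoGain` would be rejected; with the default of 0.0, any split that does not increase impurity passes this check.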

    Example usage:

    >>> from numpy import array
    >>> from pyspark.mllib.regression import LabeledPoint
    >>> from pyspark.mllib.tree import DecisionTree
    >>>
    >>> data = [
    ...     LabeledPoint(0.0, [0.0]),
    ...     LabeledPoint(1.0, [1.0]),
    ...     LabeledPoint(1.0, [2.0]),
    ...     LabeledPoint(1.0, [3.0])
    ... ]
    >>> model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
    >>> print(model)
    DecisionTreeModel classifier of depth 1 with 3 nodes
    
    >>> print(model.toDebugString())
    DecisionTreeModel classifier of depth 1 with 3 nodes
      If (feature 0 <= 0.0)
       Predict: 0.0
      Else (feature 0 > 0.0)
       Predict: 1.0
    
    >>> model.predict(array([1.0]))
    1.0
    >>> model.predict(array([0.0]))
    0.0
    >>> rdd = sc.parallelize([[1.0], [0.0]])
    >>> model.predict(rdd).collect()
    [1.0, 0.0]
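    The toy example learns a single split, `feature 0 <= 0.0`. The sketch below reproduces that choice in plain Python by trying each candidate threshold and keeping the one with the highest Gini information gain, with rough analogues of `minInstancesPerNode` and `minInfoGain`. It is a simplified illustration, not Spark's implementation: Spark limits candidate thresholds through `maxBins`, whereas the hypothetical `best_stump` helper here simply tries every distinct feature value.

```python
def gini(labels):
    """Gini impurity of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))


def majority(labels):
    """Most frequent label (ties broken arbitrarily)."""
    return max(set(labels), key=labels.count)


def best_stump(points, min_instances_per_node=1, min_info_gain=0.0):
    """Find the best split 'feature 0 <= t' for (label, feature) pairs.

    Returns (t, gain, left_prediction, right_prediction), or None if no
    candidate split satisfies the size and gain limits.
    """
    labels = [y for y, _ in points]
    parent = gini(labels)
    n = len(points)
    best = None
    # Every distinct value except the largest keeps both children non-empty.
    for t in sorted({x for _, x in points})[:-1]:
        left = [y for y, x in points if x <= t]
        right = [y for y, x in points if x > t]
        if min(len(left), len(right)) < min_instances_per_node:
            continue
        gain = parent - len(left) / n * gini(left) - len(right) / n * gini(right)
        if gain < min_info_gain:
            continue
        if best is None or gain > best[1]:
            best = (t, gain, majority(left), majority(right))
    return best


def predict(stump, x):
    """Route x left when x <= t, mirroring 'If (feature 0 <= t)'."""
    t, _, left_prediction, right_prediction = stump
    return left_prediction if x <= t else right_prediction


# Same toy data as the doctest, written as (label, feature 0) pairs.
data = [(0.0, 0.0), (1.0, 1.0), (1.0, 2.0), (1.0, 3.0)]
stump = best_stump(data)
print(stump)                # (0.0, 0.375, 0.0, 1.0)
print(predict(stump, 1.0))  # 1.0
print(predict(stump, 0.0))  # 0.0
```

    Threshold 0.0 wins because it leaves both children pure (gain 0.375), while thresholds 1.0 and 2.0 leave a mixed child and score only 0.125 and about 0.042.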

    Source: https://spark.apache.org/docs/latest/api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree

  • Original post: https://www.cnblogs.com/bonelee/p/7150483.html