  • My Spark Python decision tree example

    from numpy import array
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
    from pyspark import SparkContext
    from pyspark.mllib.evaluation import BinaryClassificationMetrics
    
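    sc = SparkContext(appName="PythonDecisionTreeClassificationExample")
    # Toy data with one continuous feature: label 1.0 when the feature is > 0, else 0.0
    data = [
        LabeledPoint(0.0, [0.0]),
        LabeledPoint(1.0, [1.0]),
        LabeledPoint(0.0, [-2.0]),
        LabeledPoint(0.0, [-1.0]),
        LabeledPoint(0.0, [-3.0]),
        LabeledPoint(1.0, [4.0]),
        LabeledPoint(1.0, [4.5]),
        LabeledPoint(1.0, [4.9]),
        LabeledPoint(1.0, [3.0])
    ]
    all_data = sc.parallelize(data)
    # Random 80/20 split. With only 9 points the test set is tiny (it can even be empty),
    # so the split and all metrics below change from run to run.
    (trainingData, testData) = all_data.randomSplit([0.8, 0.2])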
    
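    # Shortest form: model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
    # categoricalFeaturesInfo={} means every feature is treated as continuous
    model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                         impurity='gini', maxDepth=5, maxBins=32)
    print(model)
    print(model.toDebugString())
    # Predict a single feature vector...
    print(model.predict(array([1.0])))
    print(model.predict(array([0.0])))
    # ...or a whole RDD of feature vectors
    rdd = sc.parallelize([[1.0], [0.0]])
    print(model.predict(rdd).collect())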
    
    # Predict on the held-out features and pair the results with the true labels
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
    # BinaryClassificationMetrics expects (score, label) pairs
    predictionsAndLabels = predictions.zip(testData.map(lambda lp: lp.label))

    metrics = BinaryClassificationMetrics(predictionsAndLabels)
    print("AUC=%f PR=%f" % (metrics.areaUnderROC, metrics.areaUnderPR))

    # Python 3 lambdas cannot unpack tuple arguments, so index the (label, prediction) pair
    testErr = labelsAndPredictions.filter(lambda vp: vp[0] != vp[1]).count() / float(testData.count())
    print('Test Error = ' + str(testErr))
    print('Learned classification tree model:')
    print(model.toDebugString())
    
    # Save and load model
    model.save(sc, "./myDecisionTreeClassificationModel")
    sameModel = DecisionTreeModel.load(sc, "./myDecisionTreeClassificationModel")
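The model is trained with `impurity='gini'`. To show what that criterion does on this toy data, here is a minimal pure-Python sketch (no Spark required) of Gini impurity plus an exhaustive threshold search on a single continuous feature. The helper names `gini` and `best_split` are hypothetical. Using midpoints between consecutive distinct values as candidate thresholds is an illustrative assumption: Spark's own split search bins continuous features (see `maxBins`) and differs in detail.

```python
from collections import Counter
from typing import List, Optional, Tuple


def gini(labels: List[float]) -> float:
    """Gini impurity 1 - sum_k p_k^2 of a list of class labels (0.0 for an empty list)."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = Counter(labels)
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_split(xs: List[float], ys: List[float]) -> Optional[Tuple[float, float]]:
    """Try every midpoint threshold t (go left if x <= t) on one feature.

    Returns (threshold, weighted child impurity) for the lowest impurity,
    or None if the feature has fewer than two distinct values.
    """
    pairs = list(zip(xs, ys))
    values = sorted(set(xs))
    n = len(pairs)
    best = None
    for lo, hi in zip(values, values[1:]):
        t = (lo + hi) / 2.0
        left = [y for x, y in pairs if x <= t]
        right = [y for x, y in pairs if x > t]
        # Impurity of each child, weighted by the fraction of points it receives
        score = len(left) / n * gini(left) + len(right) / n * gini(right)
        if best is None or score < best[1]:
            best = (t, score)
    return best


# Feature values and labels of the nine LabeledPoints in the post
xs = [0.0, 1.0, -2.0, -1.0, -3.0, 4.0, 4.5, 4.9, 3.0]
ys = [0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
print(gini(ys))            # impurity before splitting, 40/81
print(best_split(xs, ys))  # (0.5, 0.0)
```

Every negative example has a feature value of at most 0.0 and every positive one at least 1.0, so a single threshold between them yields two pure children (weighted impurity 0.0). This is why a shallow tree is enough for this data, even though `maxDepth=5` allows more levels.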
  • Original post: https://www.cnblogs.com/bonelee/p/7151341.html