  • My Spark Python decision tree example

    from numpy import array
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
    from pyspark import SparkContext
    from pyspark.mllib.evaluation import BinaryClassificationMetrics
    
    sc = SparkContext(appName="PythonDecisionTreeClassificationExample")
    data = [
         LabeledPoint(0.0, [0.0]),
         LabeledPoint(1.0, [1.0]),
         LabeledPoint(0.0, [-2.0]),
         LabeledPoint(0.0, [-1.0]),
         LabeledPoint(0.0, [-3.0]),
         LabeledPoint(1.0, [4.0]),
         LabeledPoint(1.0, [4.5]),
         LabeledPoint(1.0, [4.9]),
         LabeledPoint(1.0, [3.0])
     ]
    all_data = sc.parallelize(data) 
    (trainingData, testData) = all_data.randomSplit([0.8, 0.2])
    
    # model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})
    model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={},
                                         impurity='gini', maxDepth=5, maxBins=32)
    print(model)
    print(model.toDebugString())
    # Predict single points, then a whole RDD of feature vectors.
    print(model.predict(array([1.0])))
    print(model.predict(array([0.0])))
    rdd = sc.parallelize([[1.0], [0.0]])
    print(model.predict(rdd).collect())
    
    predictions = model.predict(testData.map(lambda x: x.features))
    labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)

    # BinaryClassificationMetrics expects an RDD of (score, label) pairs.
    predictionsAndLabels = predictions.zip(testData.map(lambda lp: lp.label))

    metrics = BinaryClassificationMetrics(predictionsAndLabels)
    print("AUC=%f PR=%f" % (metrics.areaUnderROC, metrics.areaUnderPR))
    
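    # Fraction of test points whose prediction differs from the label.
    # With only 9 points, the 20% test split can be empty, in which case this division fails.
    testErr = labelsAndPredictions.filter(lambda vp: vp[0] != vp[1]).count() / float(testData.count())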
    print('Test Error = ' + str(testErr))
    print('Learned classification tree model:')
    print(model.toDebugString())
    
    # Save and load model
    model.save(sc, "./myDecisionTreeClassificationModel")
    sameModel = DecisionTreeModel.load(sc, "./myDecisionTreeClassificationModel")
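
The classifier above is trained with impurity='gini'. As a rough illustration of what that criterion measures, here is a pure-Python sketch (no Spark required) that scores midpoint thresholds on the same nine points by weighted Gini impurity. The helper names gini and best_threshold are made up for this illustration and are not part of MLlib. MLlib itself picks candidate splits by binning feature values (see maxBins) and only sees the training split, so the tree it prints may use a different threshold.

```python
# Illustrative only: a brute-force, single-feature split search.
# This is NOT MLlib's implementation; gini() and best_threshold() are hypothetical helpers.

def gini(labels):
    """Gini impurity of a list of 0/1 labels: 1 - p0^2 - p1^2."""
    n = len(labels)
    if n == 0:
        return 0.0
    p1 = sum(labels) / float(n)
    return 1.0 - p1 ** 2 - (1.0 - p1) ** 2


def best_threshold(points):
    """Return (threshold, weighted_gini) for the best 'x <= threshold' split."""
    xs = sorted(set(x for x, _ in points))
    # Candidate thresholds: midpoints between consecutive distinct feature values.
    candidates = [(a + b) / 2.0 for a, b in zip(xs, xs[1:])]
    best = None
    for t in candidates:
        left = [y for x, y in points if x <= t]
        right = [y for x, y in points if x > t]
        # Weight each side's impurity by its share of the points.
        score = (len(left) * gini(left) + len(right) * gini(right)) / float(len(points))
        if best is None or score < best[1]:
            best = (t, score)
    return best


# The same (feature, label) pairs as the LabeledPoint list above.
points = [(0.0, 0), (1.0, 1), (-2.0, 0), (-1.0, 0), (-3.0, 0),
          (4.0, 1), (4.5, 1), (4.9, 1), (3.0, 1)]

threshold, impurity = best_threshold(points)
print("best threshold = %s, weighted Gini = %s" % (threshold, impurity))
```

On the full nine-point dataset the two classes are separable on the single feature, so one split is enough to reach zero impurity; this is why a shallow tree suffices here even though maxDepth is set to 5.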
  • Original article: https://www.cnblogs.com/bonelee/p/7151341.html