zoukankan      html  css  js  c++  java
  • spark 针对决策树进行交叉验证

    from pyspark import SparkContext, SQLContext
    from pyspark.ml import Pipeline
    from pyspark.ml.classification import DecisionTreeClassifier
    from pyspark.ml.feature import StringIndexer, VectorIndexer
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    
    # Reuse the running SparkContext if there is one (e.g. inside the pyspark shell),
    # otherwise start a new one; wrap it in an SQLContext for DataFrame I/O.
    # (The original snippet used `sqlContext` without ever creating it.)
    sc = SparkContext.getOrCreate()
    sqlContext = SQLContext(sc)

    # Load the data stored in LIBSVM format as a DataFrame.
    data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

    # Index labels, adding metadata to the label column.
    # Fit on whole dataset to include all labels in index.
    labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data)
    # Automatically identify categorical features, and index them.
    # We specify maxCategories so features with > 4 distinct values are treated as continuous.
    # Parenthesised so the expression can span two lines (a bare `x =` newline is a SyntaxError).
    featureIndexer = (
        VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data))

    # Split the data into training and test sets (30% held out for testing)
    (trainingData, testData) = data.randomSplit([0.7, 0.3])

    # Train a DecisionTree model.
    dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures")

    # Chain indexers and tree in a Pipeline. The two indexers are already fitted
    # models, so inside the pipeline they only transform; only `dt` is trained.
    pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt])

    # Train model.  This also runs the indexers.
    model = pipeline.fit(trainingData)

    # Make predictions.
    predictions = model.transform(testData)

    # Select example rows to display.
    predictions.select("prediction", "indexedLabel", "features").show(5)

    # Select (prediction, true label) and compute test error.
    # "accuracy" is the fraction of correctly classified rows. The old "precision"
    # metric name is not accepted by MulticlassClassificationEvaluator in Spark 2.x.
    evaluator = MulticlassClassificationEvaluator(
        labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
    accuracy = evaluator.evaluate(predictions)
    print("Test Error = %g " % (1.0 - accuracy))

    # The fitted DecisionTreeClassificationModel is the third pipeline stage.
    treeModel = model.stages[2]
    # summary only
    print(treeModel)
    
    
    #############################
    
    from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
    
    # Create ParamGrid for Cross Validation over the decision tree's own hyperparameters.
    # (The original grid referenced `lr`, an undefined logistic-regression estimator.)
    # maxBins must be >= the largest categorical arity; VectorIndexer above caps it at 4.
    paramGrid = (ParamGridBuilder()
                 .addGrid(dt.maxDepth, [2, 5, 10])
                 .addGrid(dt.maxBins, [16, 32, 64])
                 .addGrid(dt.impurity, ["gini", "entropy"])
                 .build())

    # Create 5-fold CrossValidator. The estimator is the whole pipeline, not `dt` alone,
    # because `dt` reads the indexedLabel/indexedFeatures columns that only the
    # indexer stages produce; the grid's `dt.*` params are routed to the `dt` stage.
    cv = CrossValidator(estimator=pipeline, estimatorParamMaps=paramGrid,
                        evaluator=evaluator, numFolds=5)

    # Run cross validations.
    # This trains len(paramGrid) * numFolds models (18 * 5 = 90 here), so it can take a while.
    cvModel = cv.fit(trainingData)

    # Use test set here so we can measure the accuracy of our model on new data.
    # cvModel uses the best model found from the Cross Validation.
    predictions = cvModel.transform(testData)

    # Evaluate best model with the same evaluator/metric used during cross validation.
    cvAccuracy = evaluator.evaluate(predictions)
    print("CV Test Error = %g " % (1.0 - cvAccuracy))

    # A tree model has no weights/intercept; inspect its structure and feature importances
    # instead. bestModel is a PipelineModel, and the tree is its third stage.
    bestTreeModel = cvModel.bestModel.stages[2]
    print(bestTreeModel)
    print("Feature importances: ", bestTreeModel.featureImportances)
    Spark ML provides the `CrossValidator` class, which can be used to perform cross-validation and hyperparameter search. Assuming your data is already preprocessed, you can add cross-validation as follows (Scala):
    
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    
    // Placeholders: replace each ??? with a real value before running.
    // [label: double, features: vector]
    val trainingData: org.apache.spark.sql.DataFrame = ???
    val nFolds: Int = ???
    val numTrees: Int = ???
    val metric: String = ???

    val rf = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setNumTrees(numTrees)

    val pipeline = new Pipeline().setStages(Array(rf))

    // A single empty ParamMap: cross-validation only, no hyperparameter search.
    val paramGrid = new ParamGridBuilder().build()

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      // "f1" (default), "weightedPrecision", "weightedRecall", "accuracy"
      .setMetricName(metric)

    val cv = new CrossValidator()
      // ml.Pipeline with ml.classification.RandomForestClassifier
      .setEstimator(pipeline)
      // ml.evaluation.MulticlassClassificationEvaluator
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(nFolds)

    val model = cv.fit(trainingData) // trainingData: DataFrame
    Using PySpark:
    
    from pyspark.ml import Pipeline
    from pyspark.ml.classification import RandomForestClassifier
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    
    # Placeholders: replace each `...` with a real value before running.
    trainingData = ...  # DataFrame[label: double, features: vector]
    numFolds = ...  # Integer

    rf = RandomForestClassifier(labelCol="label", featuresCol="features")
    # Default metric is "f1"; set labelCol/predictionCol/metricName as in the Scala version.
    evaluator = MulticlassClassificationEvaluator()

    pipeline = Pipeline(stages=[rf])

    # A single empty ParamMap: cross-validation only, no hyperparameter search.
    # (The original snippet passed `paramGrid` without defining it.)
    paramGrid = ParamGridBuilder().build()

    crossval = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=paramGrid,
        evaluator=evaluator,
        numFolds=numFolds)

    model = crossval.fit(trainingData)
  • 相关阅读:
    Linux下查看某个命令的参数
    Vue
    SpringBoot vue
    Axios 中文说明
    一个很有趣的示例Spring Boot项目,使用Giraphe CMS和Spring Boot
    微信公众号 文章的爬虫系统
    为RecyclerView添加item的点击事件
    Android NineGridLayout — 仿微信朋友圈和QQ空间的九宫格图片展示自定义控件
    今日头条 --新闻阅读器
    int android.support.v7.widget.RecyclerView$ViewHolder.mItemViewType' on a null.....
  • 原文地址:https://www.cnblogs.com/bonelee/p/7803849.html
Copyright © 2011-2022 走看看