  • Spark ML Machine Learning Library: Evaluation Metric Examples

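    This article explains the model evaluation metrics in the Spark ML library. All code below was run in a Jupyter Notebook against Spark 2.4.5. The evaluators live in the package org.apache.spark.ml.evaluation.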

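    Model evaluation metrics should be reported on the test set, not the training set. Training-set metrics only show how well the model fits data it has already seen, not how well it generalizes to new data.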
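    A minimal plain-Scala sketch of that idea (no Spark needed). It holds out part of the data with a seeded random split, similar in spirit to Dataset.randomSplit, fits a deliberately trivial "model" on the training part only (always predicting the training-set mean, a stand-in assumed purely for illustration), and reports MSE on both parts. HoldoutSketch, split and mse are hypothetical names, not Spark API.

```scala
import scala.util.Random

object HoldoutSketch {
  // Tag each row as train (true) or test (false) with one uniform draw per row.
  def split[T](rows: Seq[T], trainFraction: Double, seed: Long): (Seq[T], Seq[T]) = {
    val rng = new Random(seed)
    val tagged = rows.map(r => (r, rng.nextDouble() < trainFraction))
    (tagged.collect { case (r, true) => r }, tagged.collect { case (r, false) => r })
  }

  // Mean squared error over (prediction, label) pairs.
  def mse(predAndLabel: Seq[(Double, Double)]): Double =
    predAndLabel.map { case (p, y) => (p - y) * (p - y) }.sum / predAndLabel.size

  def main(args: Array[String]): Unit = {
    val labels = (1 to 100).map(_.toDouble)
    val (train, test) = split(labels, 0.8, seed = 42L)
    // Hypothetical "model": the training-set mean, fitted on train only.
    val mean = train.sum / train.size
    println(s"train MSE = ${mse(train.map(y => (mean, y)))}")
    println(s"test  MSE = ${mse(test.map(y => (mean, y)))}")
  }
}
```

    Only the test-set number says anything about performance on unseen data; the training-set number is shown just for comparison.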

    1. Regression evaluation metrics

    RegressionEvaluator

    Evaluator for regression, which expects two input columns: prediction and label.

    The supported metrics are:

    val metricName: Param[String]

    • "rmse" (default): root mean squared error
    • "mse": mean squared error
    • "r2": R2 metric
    • "mae": mean absolute error
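    To make the four definitions concrete, here is a minimal plain-Scala sketch (no Spark dependency) that computes them from (prediction, label) pairs using the standard formulas: MSE is the mean squared error, RMSE its square root, MAE the mean absolute error, and r2 = 1 - SSE/SST. It assumes a non-empty input whose labels are not all identical (otherwise r2 divides by zero). RegressionMetricsSketch and metrics are illustrative names, not Spark API; Spark computes these with a distributed aggregation, but the values should agree.

```scala
// Plain-Scala sketch of the four RegressionEvaluator metrics (no Spark needed).
object RegressionMetricsSketch {
  // Each pair is (prediction, label). Assumes a non-empty input whose
  // labels are not all identical (otherwise r2 divides by zero).
  def metrics(predAndLabel: Seq[(Double, Double)]): Map[String, Double] = {
    val n = predAndLabel.size.toDouble
    val errors = predAndLabel.map { case (p, y) => p - y }
    val sse = errors.map(e => e * e).sum // sum of squared errors
    val meanLabel = predAndLabel.map(_._2).sum / n
    // total sum of squares of the labels around their mean
    val sst = predAndLabel.map { case (_, y) => (y - meanLabel) * (y - meanLabel) }.sum
    val mse = sse / n
    Map(
      "mse" -> mse,
      "rmse" -> math.sqrt(mse),
      "mae" -> (errors.map(e => math.abs(e)).sum / n),
      "r2" -> (1.0 - sse / sst)
    )
  }

  def main(args: Array[String]): Unit = {
    val data = Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0))
    for ((name, value) <- metrics(data)) println(s"$name = $value")
  }
}
```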

    Examples

    // Import dependencies
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.evaluation.RegressionEvaluator
    
    // Load the data and hold out 20% of it as a test set
    val data = spark.read.format("libsvm")
      .load("/data1/software/spark/data/mllib/sample_linear_regression_data.txt")
    val Array(training, test) = data.randomSplit(Array(0.8, 0.2))
    
    val lr = new LinearRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
    
    // Fit the model on the training set only
    val lrModel = lr.fit(training)
    
    // Summarize the model over the training set and print out some metrics
    val trainingSummary = lrModel.summary
    println(s"Train MSE: ${trainingSummary.meanSquaredError}")
    println(s"Train RMSE: ${trainingSummary.rootMeanSquaredError}")
    println(s"Train MAE: ${trainingSummary.meanAbsoluteError}")
    println(s"Train r2: ${trainingSummary.r2}")
    
    // Predict on the held-out test set
    val predictions = lrModel.transform(test)
    
    // Compute the test MSE
    val evaluator = new RegressionEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("mse")
    val mse = evaluator.evaluate(predictions)
    println(s"Test MSE: ${mse}")
    

    Output:

    Train MSE: 101.57870147367461
    Train RMSE: 10.078625971513905
    Train MAE: 8.108865602095849
    Train r2: 0.039467152584195975
    
    Test MSE: 114.28454406581636
    

    2. Classification evaluation metrics

    2.1 BinaryClassificationEvaluator

    Evaluator for binary classification, which expects two input columns: rawPrediction and label. The rawPrediction column can be of type double (binary 0/1 prediction, or probability of label 1) or of type vector (length-2 vector of raw predictions, scores, or label probabilities).

    The supported metrics are:

    val metricName: Param[String]
    param for metric name in evaluation (supports "areaUnderROC" (default), "areaUnderPR")
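    As far as I know, when rawPrediction is a vector, BinaryClassificationEvaluator uses the element at index 1 (the class-1 score) as the score for each row. The plain-Scala sketch below illustrates what areaUnderROC measures: it builds one ROC point per distinct score threshold and sums trapezoid areas. It is an illustration under those assumptions, not Spark's implementation; RocAucSketch and areaUnderROC are hypothetical names.

```scala
// Plain-Scala sketch of areaUnderROC via the trapezoidal rule (not Spark's code).
object RocAucSketch {
  // Each pair is (score, label); label 1.0 is the positive class.
  def areaUnderROC(scoreAndLabels: Seq[(Double, Double)]): Double = {
    val pos = scoreAndLabels.count(_._2 == 1.0).toDouble
    val neg = scoreAndLabels.size - pos
    require(pos > 0 && neg > 0, "both classes must be present")
    // One ROC point per distinct score, from the highest threshold down;
    // tied scores move TPR and FPR together in a single step.
    val byScore = scoreAndLabels.groupBy(_._1).toSeq.sortBy { case (s, _) => -s }
    var tp = 0.0
    var fp = 0.0
    var prevTpr = 0.0
    var prevFpr = 0.0
    var auc = 0.0
    for ((_, group) <- byScore) {
      tp += group.count(_._2 == 1.0)
      fp += group.count(_._2 != 1.0)
      val tpr = tp / pos
      val fpr = fp / neg
      auc += (fpr - prevFpr) * (tpr + prevTpr) / 2.0 // trapezoid area
      prevTpr = tpr
      prevFpr = fpr
    }
    auc
  }

  def main(args: Array[String]): Unit = {
    val data = Seq((0.9, 1.0), (0.8, 1.0), (0.7, 0.0), (0.6, 1.0), (0.2, 0.0))
    println(s"areaUnderROC = ${areaUnderROC(data)}")
  }
}
```

    For these five rows the area is 5/6: of the 3 × 2 positive/negative pairs, 5 have the positive scored higher, which is the rank interpretation of AUC. A perfect ranking, as in the test output below (Test AUC: 1.0), gives exactly 1.0.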
    

    Examples

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    
    // Load training data
    val data = spark.read.format("libsvm").load("/data1/software/spark/data/mllib/sample_libsvm_data.txt")
    
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2))
    
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
    
    // Fit the model
    val lrModel = lr.fit(train)
    
    // Summarize the model over the training set and print out some metrics
    val trainSummary = lrModel.summary
    println(s"Train accuracy: ${trainSummary.accuracy}")
    println(s"Train weightedPrecision: ${trainSummary.weightedPrecision}")
    println(s"Train weightedRecall: ${trainSummary.weightedRecall}")
    println(s"Train weightedFMeasure: ${trainSummary.weightedFMeasure}")
    
    val predictions = lrModel.transform(test)
    predictions.show(5)
    
    // Evaluate the model on the test set
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setRawPredictionCol("rawPrediction")
      .setMetricName("areaUnderROC")
    val auc = evaluator.evaluate(predictions)
    println(s"Test AUC: ${auc}")
    
    // MulticlassClassificationEvaluator also works for binary problems
    val mulEvaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("weightedPrecision")
    val precision = mulEvaluator.evaluate(predictions)
    println(s"Test weightedPrecision: ${precision}")
    

    Output:

    Train accuracy: 0.9873417721518988
    Train weightedPrecision: 0.9876110961486668
    Train weightedRecall: 0.9873417721518987
    Train weightedFMeasure: 0.9873124561568825
    
    +-----+--------------------+--------------------+--------------------+----------+
    |label|            features|       rawPrediction|         probability|prediction|
    +-----+--------------------+--------------------+--------------------+----------+
    |  0.0|(692,[122,123,148...|[0.29746771419036...|[0.57382336211209...|       0.0|
    |  0.0|(692,[125,126,127...|[0.42262389447949...|[0.60411095396791...|       0.0|
    |  0.0|(692,[126,127,128...|[0.74220898710237...|[0.67747871191347...|       0.0|
    |  0.0|(692,[126,127,128...|[0.77729372618481...|[0.68509655708828...|       0.0|
    |  0.0|(692,[127,128,129...|[0.70928896866149...|[0.67024402884354...|       0.0|
    +-----+--------------------+--------------------+--------------------+----------+
    
    Test AUC: 1.0
    
    Test weightedPrecision: 1.0
    

    2.2 MulticlassClassificationEvaluator

    Evaluator for multiclass classification, which expects two input columns: prediction and label.

    Note: since this evaluator handles multiclass classification, it also works for the binary case above.

    The supported metrics are:

    val metricName: Param[String]
    param for metric name in evaluation (supports "f1" (default), "weightedPrecision", "weightedRecall", "accuracy")
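    A rough plain-Scala sketch of how the four metrics relate, assuming they follow per-label definitions (precision = TP/(TP+FP), recall = TP/label count) averaged with weights equal to each true label's frequency, and that "f1" is the frequency-weighted F1 score. MulticlassMetricsSketch and metrics are hypothetical names; this is not Spark's code.

```scala
// Plain-Scala sketch of the MulticlassClassificationEvaluator metrics (no Spark needed).
object MulticlassMetricsSketch {
  // Each pair is (prediction, label); assumes a non-empty input.
  def metrics(predAndLabel: Seq[(Double, Double)]): Map[String, Double] = {
    val total = predAndLabel.size.toDouble
    // How often each true label occurs; these counts are the averaging weights.
    val labelCount = predAndLabel.groupBy(_._2).map { case (l, s) => l -> s.size.toDouble }
    def tp(l: Double) = predAndLabel.count { case (p, y) => p == l && y == l }.toDouble
    def fp(l: Double) = predAndLabel.count { case (p, y) => p == l && y != l }.toDouble
    def precision(l: Double) = { val d = tp(l) + fp(l); if (d == 0) 0.0 else tp(l) / d }
    def recall(l: Double) = tp(l) / labelCount(l)
    def f1(l: Double) = {
      val (p, r) = (precision(l), recall(l))
      if (p + r == 0) 0.0 else 2 * p * r / (p + r)
    }
    // Average a per-label metric, weighting each label by its frequency.
    def weighted(f: Double => Double) = labelCount.map { case (l, c) => f(l) * c / total }.sum
    Map(
      "accuracy" -> (predAndLabel.count { case (p, y) => p == y } / total),
      "weightedPrecision" -> weighted(precision),
      "weightedRecall" -> weighted(recall),
      "f1" -> weighted(f1)
    )
  }

  def main(args: Array[String]): Unit = {
    val data = Seq((0.0, 0.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0))
    for ((name, value) <- metrics(data)) println(s"$name = $value")
  }
}
```

    Under these definitions weightedRecall always equals accuracy (both are 5/6 in this example), because weighting each label's recall by its frequency just sums the correct predictions over all rows.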
    

    Examples

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.DecisionTreeClassificationModel
    import org.apache.spark.ml.classification.DecisionTreeClassifier
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
    
    // Load the data stored in LIBSVM format as a DataFrame.
    val data = spark.read.format("libsvm").load("/data1/software/spark/data/mllib/sample_libsvm_data.txt")
    
    // Index labels, adding metadata to the label column.
    // Fit on whole dataset to include all labels in index.
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(data)
    // Automatically identify categorical features, and index them.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
      .fit(data)
    
    // Split the data into training and test sets (30% held out for testing).
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
    
    // Train a DecisionTree model.
    val dt = new DecisionTreeClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
    
    // Convert indexed labels back to original labels.
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labels)
    
    // Chain indexers and tree in a Pipeline.
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))
    
    // Train model. This also runs the indexers.
    val model = pipeline.fit(trainingData)
    
    // Make predictions.
    val predictions = model.transform(testData)
    
    // Select example rows to display.
    predictions.select("predictedLabel", "label", "features").show(5)
    
    // Select (prediction, true label) and compute test error.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println(s"Test Error = ${(1.0 - accuracy)}")
    

    Output:

    +--------------+-----+--------------------+
    |predictedLabel|label|            features|
    +--------------+-----+--------------------+
    |           0.0|  0.0|(692,[95,96,97,12...|
    |           0.0|  0.0|(692,[122,123,124...|
    |           0.0|  0.0|(692,[122,123,148...|
    |           0.0|  0.0|(692,[126,127,128...|
    |           0.0|  0.0|(692,[126,127,128...|
    +--------------+-----+--------------------+
    only showing top 5 rows
    
    Test Error = 0.040000000000000036
    
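    The test error above prints as 0.040000000000000036 rather than 0.04 because 1.0 - accuracy is computed in binary floating point, and 0.96 has no exact binary representation. A minimal check, assuming a test accuracy of 0.96 (which the printed value is consistent with); TestErrorRounding is a hypothetical name.

```scala
object TestErrorRounding {
  def main(args: Array[String]): Unit = {
    val accuracy = 0.96 // assumed value, consistent with the output above
    val testError = 1.0 - accuracy
    println(s"Test Error = $testError")
    // The result is close to, but not the same double as, the literal 0.04
    println(testError == 0.04)
    // Round for display when the trailing digits are only rounding noise
    println(f"Test Error = $testError%.4f")
  }
}
```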


  • Original article: https://www.cnblogs.com/songxitang/p/12404873.html