zoukankan      html  css  js  c++  java
  • Spark CrossValidator

    1、概述

    ML中的一项重要任务是模型选择,或使用数据为给定任务找到最佳模型或参数。这也称为tuning

    可以针对单个估算器(例如LogisticRegression)进行调整,也可以针对包括多个算法,特征化和其他步骤的整个管道进行调整。用户可以一次调整整个管道,而不必分别调整管道中的每个元素。

    MLlib使用诸如CrossValidator和TrainValidationSplit之类的工具支持模型选择。这些工具需要以下各项:

    •     Estimator估计器:要调整的算法或管道
    •     一组ParamMaps:可供选择的参数,有时也称为“参数网格”以进行搜索
    •     Evaluator评估者:衡量拟合模型对保留的测试数据的良好程度的度量


    在较高级别,这些模型选择工具的工作方式如下:

    •     他们将输入数据分为单独的训练和测试数据集。
    •     对于每对(训练,测试),它们都会遍历一组ParamMap:对于每个ParamMap,他们使用这些参数拟合Estimator,获得拟合的Model,然后使用Evaluator评估Model的性能。
    •     他们选择由性能最佳的参数集生成的模型。


    该评估器可以是用于回归问题的RegressionEvaluator,用于二元数据的BinaryClassificationEvaluator或用于多元问题的MulticlassClassificationEvaluator。

    每个评估器中的setMetricName方法都可以覆盖用于选择最佳ParamMap的默认度量。

    2、Cross-Validation交叉验证

    CrossValidator首先将数据集分成一组折叠,这些折叠用作单独的训练和测试数据集。例如,k = 3
    折叠后,CrossValidator将生成3个(训练,测试)数据集对,每个对都使用2/3的数据进行训练,并使用1/3的数据进行测试。

    为了评估特定的ParamMap,CrossValidator为3个模型(通过将Estimator拟合到3个不同的(训练,测试)数据集对上)计算出平均评估指标。
    确定最佳的ParamMap之后,CrossValidator最终使用最佳的ParamMap和整个数据集重新拟合Estimator。

    3、适用情况

    当数据集比较小的时候

    交叉验证可以“充分利用”有限的数据找到合适的模型参数,防止过度拟合

    一般做深度学习跑标准数据集的时候用不到

    4、code

    package com.home.spark.ml
    
    import org.apache.spark.SparkConf
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.ml.linalg.Vector
    
    /**
      * @Description: 交叉验证选择最佳模型参数
      * 请注意,在参数网格上进行交叉验证的成本很高。
      * 例如,在下面的示例中,参数网格具有3个值的hashingTF.numFeatures和2个值的lr.regParam,而CrossValidator使用2次折叠。这乘以(3×2)×2 = 12
      * 训练不同的模型。在实际设置中,尝试更多的参数并使用更多的折叠数(通常是k = 3和k = 10)是很常见的。
      * 换句话说,使用CrossValidator可能非常昂贵。但是,这也是一种公认​​的用于选择参数的方法,该方法在统计上比启发式手动调整更合理。
      **/
    object Ex_CrossValidator {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(true).setAppName("spark ml model selection").setMaster("local[2]")
        val spark = SparkSession.builder().config(conf).getOrCreate()
    
    //    import spark.implicits._
    
        // Prepare training data from a list of (id, text, label) tuples.
        val training = spark.createDataFrame(Seq(
          (0L, "a b c d e spark", 1.0),
          (1L, "b d", 0.0),
          (2L, "spark f g h", 1.0),
          (3L, "hadoop mapreduce", 0.0),
          (4L, "b spark who", 1.0),
          (5L, "g d a y", 0.0),
          (6L, "spark fly", 1.0),
          (7L, "was mapreduce", 0.0),
          (8L, "e spark program", 1.0),
          (9L, "a e c l", 0.0),
          (10L, "spark compile", 1.0),
          (11L, "hadoop software", 0.0)
        )).toDF("id", "text", "label")
    
        // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
        val tokenizer = new Tokenizer()
          .setInputCol("text")
          .setOutputCol("words")
        val hashingTF = new HashingTF()
          .setInputCol(tokenizer.getOutputCol)
          .setOutputCol("features")
        val lr = new LogisticRegression()
          .setMaxIter(10)
        val pipeline = new Pipeline()
          .setStages(Array(tokenizer, hashingTF, lr))
    
        // We use a ParamGridBuilder to construct a grid of parameters to search over.
        // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
        // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
        val paramGrid = new ParamGridBuilder()
          .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
          .addGrid(lr.regParam, Array(0.1, 0.01))
          .build()
    
        // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
        // This will allow us to jointly choose parameters for all Pipeline stages.
        // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
        // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric
        // is areaUnderROC.
        val cv = new CrossValidator()
          .setEstimator(pipeline)
          .setEvaluator(new BinaryClassificationEvaluator)
          .setEstimatorParamMaps(paramGrid)
          .setNumFolds(2)  // Use 3+ in practice
          .setParallelism(2)  // Evaluate up to 2 parameter settings in parallel
    
        // Run cross-validation, and choose the best set of parameters.
        val cvModel = cv.fit(training)
    
        // Prepare test documents, which are unlabeled (id, text) tuples.
        val test = spark.createDataFrame(Seq(
          (4L, "spark i j k"),
          (5L, "l m n"),
          (6L, "mapreduce spark"),
          (7L, "apache hadoop")
        )).toDF("id", "text")
    
        // Make predictions on test documents. cvModel uses the best model found (lrModel).
        cvModel.transform(test)
          .select("id", "text", "probability", "prediction")
          .collect()
          .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
            println(s"($id, $text) --> prob=$prob, prediction=$prediction")
          }
    
    
        spark.stop()
      }
    }

    result:

    (4, spark i j k) --> prob=[0.25806842225846466,0.7419315777415353], prediction=1.0
    (5, l m n) --> prob=[0.9185597412653913,0.08144025873460858], prediction=0.0
    (6, mapreduce spark) --> prob=[0.43203205663918753,0.5679679433608125], prediction=1.0
    (7, apache hadoop) --> prob=[0.6766082856652199,0.32339171433478003], prediction=0.0
  • 相关阅读:
    SQL where 条件顺序对性能的影响有哪些
    性能优化实战-join与where条件执行顺序
    执行计划--WHERE条件的先后顺序对执行计划的影响
    要提高SQL查询效率where语句条件的先后次序应如何写
    winform渐变窗口显示/关闭
    Linq to Object实现分页获取数据
    无法将类型“System.Collections.Generic.IEnumerable<EmailSystem.Model.TemplateInfo>”隐式转换为“System.Collections.Generic.List<EmailSystem.Model.TemplateInf
    求本年、本月、本周等数据
    DataTable无法使用AsEnumerable ()的解决办法
    C#截取指定字符串函数
  • 原文地址:https://www.cnblogs.com/asker009/p/12426964.html
Copyright © 2011-2022 走看看