zoukankan      html  css  js  c++  java
  • Spark2.0机器学习系列之2:基于Pipeline、交叉验证、ParamMap的模型选择和超参数调优

    Spark中的CrossValidation

    • Spark中采用是k折交叉验证 (k-fold cross validation)。举个例子,例如10折交叉验证(10-fold cross validation),将数据集分成10份,轮流将其中9份做训练1份做验证,10次的结果的均值作为对算法精度的估计。
    • 10折交叉检验最常见,是因为通过利用大量数据集、使用不同学习技术进行的大量试验,表明10折是获得最好误差估计的恰当选择,而且也有一些理论根据可以证明这一点。但这并非最终结论,争议仍然存在。而且似乎5折或者20折与10折所得出的结果也相差无几。
    • 交叉检验常用于分析模型的泛化能力,提高模型的稳定。相对于手工探索式的参数调试,交叉验证更具备统计学上的意义。
    • 在Spark中,Cross Validation和ParamMap(“参数组合”的Map)结合使用。具体做法是,针对特定的Param组合,CrossValidator计算K (K 折交叉验证)个评估分数的平均值。然后和其它“参数组合”CrossValidator计算结果比较,完成所有的比较后,将最优的“参数组合”挑选出来,这“最优的一组参数”将用在整个训练数据集上重新训练(re-fit),得到最终的Model。
    • 也就是说,通过交叉验证,找到了最佳的”参数组合“,利用这组参数,在整个训练集上可以训练(fit)出一个泛化能力强,误差相对最小的的最佳模型。
    • 很显然,交叉验证计算代价很高,假设有三个参数:参数alpha有3中选择,参数beta有4种选择,参数gamma有4中选择,进行10折计算,那么将进行(3×4×4)×10=480次模型训练。

    Spark documnets 原文: 
    (1)CrossValidator begins by splitting the dataset into a set of folds which are used as separate training and test datasets. E.g., with k=3folds, CrossValidator will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular ParamMap, CrossValidator computes the average evaluation metric for the 3 Models produced by fitting the Estimator on the 3 different (training, test) dataset pairs. 
    (2)After identifying the best ParamMap, CrossValidator finally re-fits the Estimator using the best ParamMap and the entire dataset. 
    (3)Using CrossValidator to select from a grid of parameters.Note that cross-validation over a grid of parameters is expensive. E.g., in the example below, the parameter grid has 3 values for hashingTF.numFeatures and 2 values for lr.regParam, and CrossValidator uses 2 folds. This multiplies out to (3×2)×2=12different models being trained. In realistic settings, it can be common to try many more parameters and use more folds (k=3 and k=10 are common). In other words, using CrossValidator can be very expensive. However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning.

    计算流程

    //Spark Version 2.0
    package my.spark.ml.practice;
    
    import java.io.IOException;
    
    import org.apache.log4j.Level;
    import org.apache.log4j.Logger;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.ml.Pipeline;
    import org.apache.spark.ml.PipelineStage;
    import org.apache.spark.ml.evaluation.RegressionEvaluator;
    import org.apache.spark.ml.param.ParamMap;
    import org.apache.spark.ml.recommendation.ALS;
    import org.apache.spark.ml.tuning.CrossValidator;
    import org.apache.spark.ml.tuning.CrossValidatorModel;
    import org.apache.spark.ml.tuning.ParamGridBuilder;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    
    /**ALS算法协同过滤推荐算法
     * 使用Spark 2.0 基于Pipeline,ParamMap,CrossValidation
     * 对超参数进行调优,并进行模型选择
     */
    
    public class MyCrossValidation {
      public static void main(String[] args) throws IOException{
          SparkSession spark=SparkSession
                  .builder()
                  .appName("myCrossValidation")
                  .master("local[4]")
                  .getOrCreate();
        //屏蔽日志
          Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
          Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF); 
        //加载数据
          JavaRDD<Rating> ratingsRDD = spark
                  .read().textFile("/home/hadoop/spark/spark-2.0.0-bin-hadoop2.6" +
                        "/data/mllib/als/sample_movielens_ratings.txt").javaRDD()
                  .map(new Function<String, Rating>() {
                      public Rating call(String str) {
                          return Rating.parseRating(str);
                      }
                  });
          //将整个数据集划分为训练集和测试集
          //注意training集将用于Cross Validation,而test集将用于最终模型的评估
          //在traning集中,在Croos Validation时将进一步划分为K份,每次留一份作为
          //Validation,注意区分:ratings.randomSplit()分出的Test集和K 折留
          //下验证的那一份完全不是一个概念,也起着完全不同的作用,一定不要相混淆
          Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
          Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
          Dataset<Row> training = splits[0];
          Dataset<Row> test = splits[1];
    
          // Build the recommendation model using ALS on the training data
          ALS als=new ALS()
                  .setMaxIter(8)
                  .setRank(20).setRegParam(0.8)
                  .setUserCol("userId")
                  .setItemCol("movieId")
                  .setRatingCol("rating")
                  .setPredictionCol("predict_rating");
          /*
           * (1)秩Rank:模型中隐含因子的个数:低阶近似矩阵中隐含特在个数,因子一般多一点比较好,
           * 但是会增大内存的开销。因此常在训练效果和系统开销之间进行权衡,通常取值在10-200之间。
           * (2)最大迭代次数:运行时的迭代次数,ALS可以做到每次迭代都可以降低评级矩阵的重建误差,
           * 一般少数次迭代便能收敛到一个比较合理的好模型。
           * 大部分情况下没有必要进行太对多次迭代(10次左右一般就挺好了)
           * (3)正则化参数regParam:和其他机器学习算法一样,控制模型的过拟合情况。
           * 该值与数据大小,特征,系数程度有关。此参数正是交叉验证需要验证的参数之一。
           */
          // Configure an ML pipeline, which consists of one stage
          //一般会包含多个stages
          Pipeline pipeline=new Pipeline().
                  setStages(new PipelineStage[] {als});
          // We use a ParamGridBuilder to construct a grid of parameters to search over.
          ParamMap[] paramGrid=new ParamGridBuilder()
          .addGrid(als.rank(),new int[]{5,10,20})
          .addGrid(als.regParam(),new double[]{0.05,0.10,0.15,0.20,0.40,0.80})
          .build();
    
          // CrossValidator 需要一个Estimator,一组Estimator ParamMaps, 和一个Evaluator.
          // (1)Pipeline作为Estimator;
          // (2)定义一个RegressionEvaluator作为Evaluator,并将评估标准设置为“rmse”均方根误差
          // (3)设置ParamMap
          // (4)设置numFolds    
    
          CrossValidator cv=new CrossValidator()
          .setEstimator(pipeline)
          .setEvaluator(new RegressionEvaluator()
                  .setLabelCol("rating")
                  .setPredictionCol("predict_rating")
                  .setMetricName("rmse"))
          .setEstimatorParamMaps(paramGrid)
          .setNumFolds(5);
    
          // 运行交叉检验,自动选择最佳的参数组合
          CrossValidatorModel cvModel=cv.fit(training);
          //保存模型
          cvModel.save("/home/hadoop/spark/cvModel_als.modle");
    
          //System.out.println("numFolds: "+cvModel.getNumFolds());
          //Test数据集上结果评估  
          Dataset<Row> predictions=cvModel.transform(test);
          RegressionEvaluator evaluator = new RegressionEvaluator()
          .setMetricName("rmse")//RMS Error
          .setLabelCol("rating")
          .setPredictionCol("predict_rating");
          Double rmse = evaluator.evaluate(predictions);
          System.out.println("RMSE @ test dataset " + rmse);
          //Output: RMSE @ test dataset 0.943644792277118
      }   
    }
    

    备注:程序运行需要定义Rating Class 在下面链接里可以找到: http://spark.apache.org/docs/latest/ml-collaborative-filtering.html

  • 相关阅读:
    window 编译lua 5.3
    邮件服务器软件
    mkyaffs2image 生成不了120M的镜像文件的解决方法
    C static struct
    uboot 如何向内核传递参数
    linux 链接理解
    snmp 协议之理解
    交叉编译知识点总结
    回滚原理 Since database connections are thread-local, this is thread-safe.
    REST 架构的替代方案 为什么说GraphQL是API的未来?
  • 原文地址:https://www.cnblogs.com/itboys/p/8310134.html
Copyright © 2011-2022 走看看