zoukankan      html  css  js  c++  java
  • Spark快速获得CrossValidator的最佳模型参数

    Spark提供了便利的Pipeline模型,可以轻松的创建自己的学习模型。

    但是大部分模型都是需要提供参数的,如果不提供就是默认参数,那么怎么选择参数就是一个比较常见的问题。Spark提供在org.apache.spark.ml.tuning包下提供了模型选择器,可以替换参数然后比较模型输出。

    目前有CrossValidator和TrainValidationSplit两种,比如一个文本情感预测模型。

    Pipeline只有三步,第一步切词,第二步HashingTF,第三步NB分类

    Pipeline pipeline = new Pipeline()
                    .setStages(new PipelineStage[]{tokenizer, hashingTF, naiveBayes});
    
    ParamMap[] paramMaps = new ParamGridBuilder()
                    .addGrid(hashingTF.numFeatures(), new int[]{10000, 100000, 500000, 1000000})
                    .build();
    CrossValidator cv = new CrossValidator()
                    .setEstimator(pipeline)
                    .setEvaluator(new BinaryClassificationEvaluator())
                    .setEstimatorParamMaps(paramMaps);

    其中HashingTF的参数选择非常重要,我们这里就随便尝试几种,然后放在CrossValidator中去。

    最后我们会获得一个CrossValidatorModel类,这里有两种选择。

    第一种是自己手动获取其中的参数,因为bestModel的参数就是我们最后选择的参数

    Pipeline bestPipeline = (Pipeline) model.bestModel().parent();
    PipelineStage stage = bestPipeline.getStages()[1];
    stage.extractParamMap().get(stage.getParam("numFeatures"));

    这种方法可以获得值,但是需要根据你模型情况修改获取的位置。

    如果你只是想知道最佳参数是多少,并不是需要在上下文中使用,那还有一个更简单的方法。

    修改log4j的配置,添加

    log4j.logger.org.apache.spark.ml.tuning.TrainValidationSplit=INFO
    log4j.logger.org.apache.spark.ml.tuning.CrossValidator=INFO

    效果如下:

  • 相关阅读:
    Winform中让回车键完成TAB键的功能
    ASP.NET跨页传值方法汇总
    SQL SERVER中使用Unicode字符的注意问题
    如何为Oracle配置多个监听器
    如何实现上一条、下一条的功能
    "文件中的备份集是由BACKUP DATABASE...FILE=创建的,无法用于此还原操作"的解决办法
    [psp][lumines]dat数据包解包程序
    meteos@pc, the remake制作中...
    最近在仿照Lumines写
    建立huffman树,当然用堆排序
  • 原文地址:https://www.cnblogs.com/itboys/p/9827567.html
Copyright © 2011-2022 走看看