zoukankan      html  css  js  c++  java
  • xgboost的SparkWithDataFrame版本实现

      在xgboost的源码中有xgboost的SparkWithDataFrame的实现,如下:https://github.com/dmlc/xgboost/tree/master/jvm-packages。但是由于各种各样的原因,这些代码在我的IDE里面编译不过,因此又写了如下代码以供以后查阅使用。

    package xgboost
    
    import ml.dmlc.xgboost4j.scala.spark.{XGBoost, XGBoostModel}
    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
    import org.apache.spark.sql.{Row, DataFrame, SparkSession}
    
    /**
      * Example driver for XGBoost4J-Spark on DataFrames.
      *
      * Loads libsvm-formatted train/test files, trains a binary classifier with
      * `XGBoost.trainWithDataFrame`, prints the area under the ROC curve on the test
      * set, and saves the model both through `saveModelAsHadoopFile` and as a text
      * dump of the individual trees.
      */
    object App {

      /**
        * Runs the train / evaluate / save pipeline.
        *
        * Input and output paths are hard-coded placeholders and must be edited
        * before running.
        *
        * @param args command-line arguments (unused)
        */
      def main(args: Array[String]): Unit = {
        val trainPath: String = "xxx/train.txt"
        val testPath: String = "xxx/test.txt"
        val binaryModelPath: String = "xxx/model.binary"
        val textModelPath: String = "xxx/model.txt"
        val spark = SparkSession
          .builder()
          .master("yarn")
          .getOrCreate()

        // Make sure the session is released even when training or saving fails.
        try {
          // define xgboost parameters
          val maxDepth = 3
          val numRound = 4
          val nworker = 1
          val paramMap = Map[String, Any](
            "eta" -> 0.1,
            "max_depth" -> maxDepth,
            "objective" -> "binary:logistic")

          // read libsvm files; the two loaded columns are renamed to labelCol / featureCol
          val dfTrain = spark.read.format("libsvm").load(trainPath).toDF("labelCol", "featureCol")
          val dfTest = spark.read.format("libsvm").load(testPath).toDF("labelCol", "featureCol")
          dfTrain.show(truncate = true)
          println("begin...")
          val model: XGBoostModel = XGBoost.trainWithDataFrame(dfTrain, paramMap, numRound, nworker,
            useExternalMemory = true,
            featureCol = "featureCol", labelCol = "labelCol",
            missing = 0.0f)

          // predict the test set
          val predict: DataFrame = model.transform(dfTest)
          // NOTE(review): if the prediction column holds hard 0/1 class labels rather
          // than scores, the ROC curve is built from a single threshold; consider
          // feeding the positive-class probability instead -- confirm.
          // The pattern below throws a MatchError if either column is not a Double.
          val scoreAndLabels = predict.select(model.getPredictionCol, model.getLabelCol)
            .rdd
            .map { case Row(score: Double, label: Double) => (score, label) }

          // get the auc
          val metric = new BinaryClassificationMetrics(scoreAndLabels)
          val auc = metric.areaUnderROC()
          println("auc:" + auc)

          // save model
          this.saveBinaryModel(model, spark, binaryModelPath)
          this.saveTextModel(model, spark, textModelPath, numRound, maxDepth)
        } finally {
          spark.stop()
        }
      }

      /**
        * Saves the model by calling `saveModelAsHadoopFile` with the session's SparkContext.
        *
        * @param model trained model to persist
        * @param spark active session whose SparkContext is passed to the save call
        * @param path  destination path
        */
      def saveBinaryModel(model: XGBoostModel, spark: SparkSession, path: String): Unit = {
        model.saveModelAsHadoopFile(path)(spark.sparkContext)
      }

      /**
        * Writes a text dump of the model: one header line followed by one
        * `booster:[i]` entry per dumped tree.
        *
        * The output is coalesced to a single partition and any existing output at
        * `path` is overwritten.
        *
        * @param model    trained model whose booster is dumped
        * @param spark    active session (supplies the implicits needed for `toDF`)
        * @param path     destination path of the text output
        * @param numRound number of boosting rounds, recorded in the header line
        * @param maxDepth maximum tree depth, recorded in the header line
        */
      def saveTextModel(model: XGBoostModel, spark: SparkSession, path: String, numRound: Int, maxDepth: Int): Unit = {
        // Prefix every dumped tree with its index; the "\n" separates the label from the tree body.
        val dumpModel = model
          .booster
          .getModelDump()
          .toList
          .zipWithIndex
          .map { case (tree, index) => s"booster:[$index]\n$tree" }

        val header = s"numRound: $numRound, maxDepth: $maxDepth"
        print(dumpModel)
        import spark.implicits._
        val text: List[String] = header +: dumpModel
        text.toDF
          .coalesce(1)
          .write
          .mode("overwrite")
          .text(path)
      }
    }
    

      其中:

      1.训练集和测试集都是libsvm格式,如下所示:

    1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1
    0 3:1 10:1 20:1 21:1 23:1 34:1 36:1 39:1 41:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 120:1

      2.最终生成的模型如下所示:

    numRound: 4, maxDepth: 3
    booster:[0]
    0:[f29<2] yes=1,no=2,missing=2
        1:leaf=0.152941
        2:leaf=-0.191209
    
    booster:[1]
    0:[f29<2] yes=1,no=2,missing=2
        1:leaf=0.141901
        2:leaf=-0.174499
    
    booster:[2]
    0:[f29<2] yes=1,no=2,missing=2
        1:leaf=0.132731
        2:leaf=-0.161685
    
    booster:[3]
    0:[f29<2] yes=1,no=2,missing=2
        1:leaf=0.124972
        2:leaf=-0.15155

      相关解释:“numRound: 4, maxDepth: 3”表示生成树的个数为4,树的最大深度为3;booster:[n]表示第n棵树(n从0开始);其后保存树的结构,0号节点为根节点,每个非叶子节点有两个子节点,节点序号按层序计数,即1号和2号节点为根节点(0号节点)的子节点,相同层的节点有相同缩进,且比父节点多一级缩进。
      在节点行,首先声明节点序号,中括号里写明该节点采用哪个特征(如f29即为下标为29的特征,下标从0开始计数),同时表明特征值划分条件。“[f29<2] yes=1,no=2,missing=2”表示:f29号特征小于2时该样本划分到1号叶子节点,f29>=2时划分到2号叶子节点,当该特征缺失(None)时划分到2号叶子节点。

      3.预测的结果如下:

    |labelCol|featureCol                                                                                                                                                  |probabilities                          |prediction|
    |1.0     |(126,[2,9,10,20,29,33,35,39,40,52,57,64,68,76,85,87,91,94,101,104,116,123],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.3652743101119995,0.6347256898880005]|1.0       |
    |0.0     |(126,[2,9,19,20,22,33,35,38,40,52,55,64,68,76,85,87,91,94,101,105,115,119],[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0])|[0.6635029911994934,0.3364970088005066]|0.0       |
    

      

  • 相关阅读:
    实习生Python炫技却被主管教育?原来是这样!
    Python炫技操作却被骂,为啥?
    你要是能学会这招,还能没有小姐姐吗!
    用Python快速从深层嵌套 JSON 中找到特定的 Key
    哪儿网领域驱动设计(DDD)实践之路 Qunar技术沙龙 2021-05-11
    闲鱼单体应用Serverless化拆分实践 原创 柬超 闲鱼技术 今天
    // context canceled ctx := context.Background()
    Virtual DOM(虚拟DOM)
    新一代Web技术栈的演进:SSR/SSG/ISR/DPR都在做什么?
    延迟队列浅析 原创 张浩 网易传媒技术团队 2019-08-02
  • 原文地址:https://www.cnblogs.com/zhaochunhua/p/6723660.html
Copyright © 2011-2022 走看看