zoukankan      html  css  js  c++  java
  • lakala GradientBoostedTrees

    /**
      * Created by lkl on 2017/12/6.
      */
    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.{SparkConf, SparkContext}
    import scala.collection.mutable.ArrayBuffer
    /**
      * Trains a gradient-boosted-trees binary classifier on a Hive table and writes
      * evaluation artifacts (predictions, precision/recall/F1 by threshold, PR curve,
      * ROC curve, AUPRC, AUC, test error, model dump) to HDFS under /tmp/<date>/.
      *
      * Usage: <database> <table> <date>
      *   e.g. lkl_card_score overdue_result_all_new_woe 20171206
      *
      * Reference: http://blog.csdn.net/xubo245/article/details/51499643
      */
    object GradientBoostingClassificationForLK {

      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("GradientBoostingClassificationForLK")
        val sc = new SparkContext(conf)
        val hc = new HiveContext(sc)

        if (args.length != 3) {
          println("请输入参数:trainingData对应的库名、表名、模型运行时间")
          System.exit(0)
        }

        // BUGFIX: the original validated three args and then ignored them,
        // using hard-coded values; the validated arguments are now used.
        val database = args(0) // e.g. "lkl_card_score"
        val table    = args(1) // e.g. "overdue_result_all_new_woe"
        val date     = args(2) // run tag embedded in every output path, e.g. "20171206"

        // Build RDD[LabeledPoint]. Column 0 is the integer label; columns 1-2
        // (label/contact related fields) are skipped; columns 3.. become features.
        // NULLs and strings map to 0.0; pattern match replaces the isInstanceOf chain.
        val data = hc.sql(s"select * from $database.$table").map { row =>
          val features = new ArrayBuffer[Double]()
          for (i <- 3 until row.size) {
            val v: Double =
              if (row.isNullAt(i)) 0.0
              else row.get(i) match {
                case n: Int    => n.toDouble
                case d: Double => d
                case l: Long   => l.toDouble
                case _: String => 0.0
                // BUGFIX: the original appended nothing for unrecognized types,
                // which could misalign feature vectors across rows; default to 0.0.
                case _         => 0.0
              }
            features += v
          }
          LabeledPoint(row.getInt(0), Vectors.dense(features.toArray))
        }

        // Split the data into training and test sets (30% held out for testing).
        val splits = data.randomSplit(Array(0.7, 0.3))
        val (trainingData, testData) = (splits(0), splits(1))

        // The defaultParams for Classification use LogLoss by default.
        val boostingStrategy = BoostingStrategy.defaultParams("Classification")
        boostingStrategy.setNumIterations(3) // Note: use more iterations in practice.
        boostingStrategy.treeStrategy.setNumClasses(2)
        boostingStrategy.treeStrategy.setMaxDepth(5)
        // Empty categoricalFeaturesInfo indicates all features are continuous.

        val model = GradientBoostedTrees.train(trainingData, boostingStrategy)

        // BUGFIX: BinaryClassificationMetrics expects (score, label) pairs; the
        // original passed (label, prediction), corrupting every metric below.
        // Cached because it feeds the text dump, the metrics, and the error count.
        val predictionAndLabels = testData.map { point =>
          (model.predict(point.features), point.label)
        }.cache()

        predictionAndLabels
          .map { case (p, l) => s"predicts: $p--> labels:$l" }
          .saveAsTextFile(s"hdfs://ns1/tmp/$date/predictionAndLabels")

        //===================================================================
        // 使用BinaryClassificationMetrics评估模型
        val metrics = new BinaryClassificationMetrics(predictionAndLabels)

        // Precision by threshold
        metrics.precisionByThreshold
          .map { case (t, p) => s"Threshold: ${t}Precision:$p" }
          .saveAsTextFile(s"hdfs://ns1/tmp/$date/precision")

        // Recall by threshold
        metrics.recallByThreshold
          .map { case (t, r) => s"Threshold: ${t}Recall:$r" }
          .saveAsTextFile(s"hdfs://ns1/tmp/$date/recall")

        // F-measure by threshold (beta = 1). Of the three threshold-indexed
        // metrics, F1 is the natural one to maximize when choosing a threshold.
        metrics.fMeasureByThreshold
          .map { case (t, f) => s"Threshold: $t--> F-score:$f--> Beta = 1" }
          .saveAsTextFile(s"hdfs://ns1/tmp/$date/f1Score")

        // Precision-Recall curve and the area under it (AUPRC).
        metrics.pr
          .map { case (r, p) => s"Recall: $r--> Precision: $p" }
          .saveAsTextFile(s"hdfs://ns1/tmp/$date/prc")
        sc.makeRDD(Seq("Area under precision-recall curve = " + metrics.areaUnderPR))
          .saveAsTextFile(s"hdfs://ns1/tmp/$date/auPRC")

        // ROC curve and area under it (AUC).
        metrics.roc
          .map { case (fpr, tpr) => s"FalsePositiveRate:$fpr--> Recall: $tpr" }
          .saveAsTextFile(s"hdfs://ns1/tmp/$date/roc")
        val auROC = metrics.areaUnderROC
        sc.makeRDD(Seq("Area under ROC = " + auROC)).saveAsTextFile(s"hdfs://ns1/tmp/$date/auROC")
        println("Area under ROC = " + auROC)

        // BUGFIX: this is the misclassification rate, not a mean squared error;
        // the output message is corrected accordingly.
        val testErr = predictionAndLabels.filter { case (p, l) => p != l }.count.toDouble / testData.count()
        sc.makeRDD(Seq("Test Error = " + testErr)).saveAsTextFile(s"hdfs://ns1/tmp/$date/testErr")
        // BUGFIX: this is a classification GBT, not a regression tree.
        sc.makeRDD(Seq("Learned classification GBT model: " + model.toDebugString))
          .saveAsTextFile(s"hdfs://ns1/tmp/$date/GBDTclassification")

        sc.stop()
      }

    }
  • 相关阅读:
    [NOI2002]银河英雄传说
    Splay普及版
    线段树普及版
    长连接与短连接
    【HTTP】中Get/Post请求区别
    【HTML】知识笔记
    SVN使用教程总结
    《人生只有一次,去做自己喜欢的事》读书笔记
    【HTTP】无状态无连接的含义
    【HTML】解析原理
  • 原文地址:https://www.cnblogs.com/canyangfeixue/p/8006103.html
Copyright © 2011-2022 走看看