zoukankan      html  css  js  c++  java
  • 机器学习结果加ID插入数据库源码

    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
    import org.apache.spark.sql.{Row, SaveMode}
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
    import org.apache.spark.{SparkConf, SparkContext}
    import scala.collection.mutable.ArrayBuffer
    /**
      * Batch-scoring job: reads a Hive dataset, scores every row with a previously
      * trained GradientBoostedTrees model loaded from HDFS, and writes the scores
      * (together with the identifying columns) back to a Hive table.
      *
      * Input layout, as used by this code: columns 0, 1 and 2 are carried through
      * unchanged and written out as `order_id`, `apply_time` and `label`; every
      * column from index 3 onwards is turned into one model feature.
      */
    object v4score20180123 {

      /** Index of the first column that is fed to the model as a feature. */
      private val FirstFeatureColumn = 3

      /**
        * Entry point of the Spark application.
        *
        * @param args command-line arguments (unused)
        */
      def main(args: Array[String]): Unit = {
        val sparkConf = new SparkConf().setAppName("v4model20180123")
        val sc = new SparkContext(sparkConf)
        try {
          val hc = new HiveContext(sc)

          // `.rdd` makes the RDD conversion explicit: identical to DataFrame.map on
          // Spark 1.x, and avoids the Encoder requirement of Dataset.map on 2.x.
          val dataInstance = hc
            .sql("select * from lkl_card_score.fqz_score_dataset_04vals")
            .rdd
            .map { row =>
              // The original note here said the label and phone fields are excluded;
              // what the code does is skip the first three columns, which are passed
              // through as identifiers instead of being used as features.
              (row(0), row(1), row(2), Vectors.dense(featureValues(row, FirstFeatureColumn)))
            }

          val modeltest = GradientBoostedTreesModel.load(
            sc, "hdfs://ns1/user/songchunlin/model/v4model20180123s")

          // (order_id, apply_time, label, score)
          val preditDataGBDT = dataInstance.map { point =>
            (point._1, point._2, point._3, modeltest.predict(point._4))
          }

          // Convert the scored RDD into a DataFrame. Identifier columns may be SQL
          // NULL (the schema below declares them nullable), so they are stringified
          // null-safely instead of calling `.toString` directly.
          val rowRDD = preditDataGBDT.map { scored =>
            Row(
              nullableString(scored._1),
              nullableString(scored._2),
              nullableString(scored._3),
              scored._4)
          }
          val schema = StructType(
            List(
              StructField("order_id", StringType, true),
              StructField("apply_time", StringType, true),
              StructField("label", StringType, true),
              StructField("score", DoubleType, true)
            )
          )
          val scoreDataFrame = hc.createDataFrame(rowRDD, schema)

          // Running the pipeline once before the overwrite surfaces scoring errors
          // before the target table is replaced; the count is logged for reference.
          val scoredRows = scoreDataFrame.count()
          println(s"Scored rows to be written: $scoredRows")

          // NOTE(review): the source table is "..._04vals" while the target is
          // "..._03val_..." - confirm the naming is intentional.
          scoreDataFrame.write
            .mode(SaveMode.Overwrite)
            .saveAsTable("lkl_card_score.fqz_score_dataset_03val_v4_predict0123s")
        } finally {
          // Always release cluster resources, even when the job fails.
          sc.stop()
        }
      }

      /**
        * Builds the feature values of one row from the columns starting at
        * `firstFeatureColumn`.
        *
        * Exactly one value is produced per column so that every row yields a
        * vector of the same length: numeric values are converted to Double, while
        * nulls, strings and any other non-numeric value become 0.0.
        *
        * @param row                the input row
        * @param firstFeatureColumn index of the first column to use as a feature
        * @return one Double per column in `firstFeatureColumn until row.size`
        */
      private def featureValues(row: Row, firstFeatureColumn: Int): Array[Double] = {
        val featureCount = math.max(row.size - firstFeatureColumn, 0)
        Array.tabulate(featureCount) { offset =>
          val column = firstFeatureColumn + offset
          if (row.isNullAt(column)) 0.0
          else row.get(column) match {
            // Covers Int, Long and Double as before, plus Float/Short/Byte/decimal.
            case number: java.lang.Number => number.doubleValue()
            // Strings and any other non-numeric type carry no numeric signal here.
            case _ => 0.0
          }
        }
      }

      /** Returns `value.toString`, or null when `value` itself is null (SQL NULL). */
      private def nullableString(value: Any): String =
        Option(value).map(_.toString).orNull
    }
    

      

  • 相关阅读:
    查找小岛个数
    非递归遍历树的总结(前中后序)
    Java的TreeMap,C++的lower_bound,合并间隔
    最多包含2/k个不同字符的最长串
    爆气球这道题目,展开了新的思路
    C++的hashmap和Java的hashmap
    求数组里重复出现的数字
    数组中出现一次的两个数(三个数)& 求最后一位bit为1
    皇后问题的经典做法
    海外省电应用市场:本土化为先锋,高技术为基础
  • 原文地址:https://www.cnblogs.com/canyangfeixue/p/8376498.html
Copyright © 2011-2022 走看看