zoukankan      html  css  js  c++  java
  • lkl风控：随机森林模型测试代码（Spark 1.6）

    /**
      * Created by lkl on 2017/10/9.
      */
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.SparkConf
    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    /**
      * Random-forest model test job (Spark 1.6, RDD-based MLlib API).
      *
      * Trains a binary classifier on the Hive table `fin_tec.uvcy2` and prints its error on a random
      * 30% hold-out split together with the learned forest. It then saves the model to
      * `uvcymodel/forest`, reloads it, and prints predictions for the rows of `test.uvcy` where
      * `i_l3_hk_amt=2150`.
      *
      * Expected table layout (from the original author's note):
      *   - column 0 is the ID-card number;
      *   - column 1 is the overdue flag (the label, stored in Hive as a double);
      *   - every column from index 2 onward is a numeric feature.
      * `test.uvcy` is assumed to share this layout — confirm against the Hive schema.
      */
    object uvcy {

      /** Index of the first feature column; columns 0 and 1 hold the ID and the label. */
      private val FirstFeatureIndex = 2

      /**
        * Converts a row's feature cells to doubles, one value per cell.
        *
        * Conversion rules:
        *   - a null cell becomes 0.0;
        *   - any `java.lang.Number` (Double, Float, Long, Int, Short, Byte, BigDecimal, ...) is
        *     widened with `doubleValue`;
        *   - a String is parsed with `toDouble`.
        * Any other cell type is rejected. Silently skipping it would shift every later feature to a
        * different index and give rows feature vectors of different lengths.
        *
        * @param cells the feature cells in column order (the row without its ID and label columns)
        * @return the feature values, in the same order as `cells`
        * @throws IllegalArgumentException if a cell has an unsupported type
        * @throws NumberFormatException if a String cell is not a valid number
        */
      private def toFeatures(cells: Seq[Any]): Array[Double] =
        cells.map {
          case null      => 0.0
          case n: Number => n.doubleValue
          case s: String => s.toDouble
          case other =>
            throw new IllegalArgumentException(
              s"Unsupported feature type ${other.getClass.getName}: $other")
        }.toArray

      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("test") //setMaster("spark://192.168.0.37:7077")
        val sc = new SparkContext(conf)
        try {
          val hc = new HiveContext(sc)

          // Training data: label from column 1, features from column 2 onward.
          // Cached because randomSplit, training and the counts below each traverse this RDD.
          // Without the cache, every pass would re-run the Hive query.
          val data = hc.sql("select * from  fin_tec.uvcy2").map { row =>
            val features = toFeatures(row.toSeq.drop(FirstFeatureIndex))
            LabeledPoint(row.getDouble(1), Vectors.dense(features))
          }.cache()

          val splits = data.randomSplit(Array(0.7, 0.3))
          val (trainingData, testData) = (splits(0), splits(1))

          val numClasses = 2
          // Empty map: no feature is declared categorical, so all are treated as continuous.
          val categoricalFeaturesInfo = Map[Int, Int]()
          val numTrees = 3
          val featureSubsetStrategy = "auto"
          val impurity = "gini"
          val maxDepth = 4
          val maxBins = 32
          val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
            numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

          // Fraction of hold-out rows whose predicted class differs from the label.
          val labelAndPreds = testData.map(point => (point.label, model.predict(point.features)))
          val testErr = labelAndPreds.filter { case (label, pred) => label != pred }.count.toDouble /
            testData.count()
          println("Test Error = " + testErr)
          println("Learned classification forest model:\n" + model.toDebugString)

          // NOTE(review): MLlib model saving appears to fail when the target path already exists.
          // Confirm this; reruns may need "uvcymodel/forest" removed first.
          model.save(sc, "uvcymodel/forest")

          // Reload the persisted model and score newly queried rows with it.
          val sameModel = RandomForestModel.load(sc, "uvcymodel/forest")
          val data3 = hc.sql("select * from test.uvcy where i_l3_hk_amt=2150")
          // (ID, features) per row; the ID comes from column 0.
          val datas = data3.map { row =>
            (String.valueOf(row.get(0)), Vectors.dense(toFeatures(row.toSeq.drop(FirstFeatureIndex))))
          }
          val idAndPreds = datas.map { case (id, features) =>
            (id, sameModel.predict(features), features)
          }
          idAndPreds.take(2).foreach(println)
        } finally {
          sc.stop()
        }
      }
    }
    
  • 相关阅读:
    数据库
    流式布局
    ScrollView简单用法
    ADB被占用解决办法
    安卓中shape中的属性大全
    sql语句replace into的用法
    debug
    大数据量数据库优化
    Gson解析后的数据存到本地数据库 耗时的问题
    数据同步异步加载handler Looper
  • 原文地址:https://www.cnblogs.com/canyangfeixue/p/7762521.html
Copyright © 2011-2022 走看看