zoukankan      html  css  js  c++  java
  • lkl风控.随机森林模型测试代码spark1.6

    /**
      * Created by lkl on 2017/10/9.
      */
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.SparkConf
    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    /**
      * Random-forest classification test job written against the Spark 1.6 MLlib
      * (RDD-based) API.
      *
      * The job reads the Hive table `fin_tec.uvcy2`, converts every row into a
      * [[org.apache.spark.mllib.regression.LabeledPoint]], trains a random forest on a
      * 70/30 random split, prints the test error and the learned forest, saves the
      * model, reloads it and prints two sample predictions.
      */
    object uvcy {
      import org.apache.spark.sql.Row

      // Original author's note: column 0 is the ID-card number and column 1 is the
      // overdue flag; the columns are said to be stored as double in Hive.
      private val LabelColumn = 1
      private val FirstFeatureColumn = 2

      /** Location the trained forest is saved to and reloaded from. */
      private val ModelPath = "uvcymodel/forest"

      /**
        * Converts the cell at position `i` of `row` to a `Double`.
        *
        * Null cells become `0.0`; any boxed numeric value is widened with `doubleValue()`;
        * strings are parsed with `toDouble`.
        *
        * @throws NumberFormatException    if a string cell is not a valid number
        * @throws IllegalArgumentException if the cell holds any other type. Failing here
        *                                  (instead of skipping the cell) keeps every
        *                                  feature vector the same length and column-aligned.
        */
      private def cellToDouble(row: Row, i: Int): Double =
        if (row.isNullAt(i)) 0.0
        else row.get(i) match {
          case n: java.lang.Number => n.doubleValue()
          case s: String           => s.toDouble
          case other =>
            throw new IllegalArgumentException(
              s"Unsupported value type ${other.getClass.getName} in column $i")
        }

      /** Builds a dense feature vector from all columns starting at `FirstFeatureColumn`. */
      private def featureVector(row: Row): org.apache.spark.mllib.linalg.Vector =
        Vectors.dense((FirstFeatureColumn until row.size).map(cellToDouble(row, _)).toArray)

      /**
        * Builds a labeled point: label from `LabelColumn`, features from the remaining columns.
        * `getDouble` is used as in the original code, so the label column must be a
        * non-null double.
        */
      private def labeledPoint(row: Row): LabeledPoint =
        LabeledPoint(row.getDouble(LabelColumn), featureVector(row))

      /**
        * Job entry point.
        *
        * @param args command-line arguments (not used)
        */
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("test") //setMaster("spark://192.168.0.37:7077")
        val sc = new SparkContext(conf)
        try {
          val hc = new HiveContext(sc)

          val data = hc.sql("select * from  fin_tec.uvcy2").map(labeledPoint)
          val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

          // Forest hyper-parameters.
          val numClasses = 2
          val categoricalFeaturesInfo = Map[Int, Int]() // empty: all features are continuous
          val numTrees = 3
          val featureSubsetStrategy = "auto"
          val impurity = "gini"
          val maxDepth = 4
          val maxBins = 32

          val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
            numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

          val labelAndPreds = testData.map(point => (point.label, model.predict(point.features)))
          // Fraction of test points whose prediction differs from the label.
          val testErr =
            labelAndPreds.filter { case (label, prediction) => label != prediction }.count.toDouble /
              testData.count()
          println("Test Error = " + testErr)
          println("Learned classification forest model:\n" + model.toDebugString)

          model.save(sc, ModelPath)
          val sameModel = RandomForestModel.load(sc, ModelPath)

          val id = "110101000000000000"
          // NOTE(review): these vectors are built from `test.uvcy` but never scored; the
          // predictions below are still made on `testData`, exactly as in the original.
          // Confirm which data set was intended before relying on this section.
          val datas = hc.sql("select * from test.uvcy where i_l3_hk_amt=2150").map(featureVector)

          val labelAndPreds2 = testData.map { point =>
            (id, point.label, sameModel.predict(point.features), point.features)
          }
          // Print the samples; in a compiled job a bare `take(2)` would be discarded.
          labelAndPreds2.take(2).foreach(println)
        } finally {
          // Always release the Spark context, even when the job fails.
          sc.stop()
        }
      }
    }
    
  • 相关阅读:
    解疑网络监控卡壳 视觉体验400ms延时
    DDD实战11 在项目中使用JWT的token 进行授权验证
    在sql语句中 inner join ,left join,right join 和on 以及where
    在.net core自带DI中服务生命周期 Transient,Scoped,Singleton
    DDD实战10 在项目中使用JWT的token
    简单使用.net core 自带的DI
    在.net core中一个简单的加密算法
    DDD实战9 经销商领域上下文
    DDD实战8_2 利用Unity依赖注入,实现接口对应实现类的可配置
    DDD实战8_1 实现对领域中连接字符串的可配置
  • 原文地址:https://www.cnblogs.com/canyangfeixue/p/7762521.html
Copyright © 2011-2022 走看看