zoukankan      html  css  js  c++  java
  • lkl风控：随机森林模型测试代码（Spark 1.6）

    /**
      * Created by lkl on 2017/10/9.
      */
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.SparkConf
    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    /**
      * Trains a binary random-forest classifier on Hive table `fin_tec.uvcy2`, prints the
      * hold-out error and the forest, saves the model, reloads it and scores rows from
      * `test.uvcy`.
      *
      * Expected row layout: column 0 is the national ID number, column 1 is the overdue
      * flag (the label, read as a non-null double), columns 2 until row.size are features.
      * The original note says all fields are stored as double in Hive; any numeric type or
      * numeric string is nevertheless accepted for feature columns.
      */
    object uvcy {

      /** Index of the first feature column; columns 0 and 1 are the ID and the label. */
      private val FirstFeatureCol = 2

      /**
        * Converts one cell to a feature value.
        *
        * Nulls become 0.0, any `java.lang.Number` (Double, Long, Int, Float, BigDecimal, ...)
        * is widened to Double, and strings are parsed with `toDouble`.
        *
        * @throws IllegalArgumentException if the cell holds a non-numeric, non-string value
        * @throws NumberFormatException    if a string cell is not a valid number
        */
      private def cellToDouble(row: org.apache.spark.sql.Row, i: Int): Double =
        if (row.isNullAt(i)) 0.0
        else row.get(i) match {
          case n: java.lang.Number => n.doubleValue()
          case s: String           => s.toDouble
          case other =>
            // Skipping such a cell (the old behaviour) shifted every later feature and
            // produced vectors of different lengths, so fail loudly instead.
            throw new IllegalArgumentException(
              s"Unsupported type ${other.getClass.getName} in feature column $i")
        }

      /** Builds the dense feature vector from columns FirstFeatureCol until row.size. */
      private def toFeatures(row: org.apache.spark.sql.Row): org.apache.spark.mllib.linalg.Vector =
        Vectors.dense((FirstFeatureCol until row.size).map(i => cellToDouble(row, i)).toArray)

      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("test") //setMaster("spark://192.168.0.37:7077")
        val sc = new SparkContext(conf)
        val hc = new HiveContext(sc)

        // Column 0 = ID number, column 1 = overdue flag (label), the rest are features.
        // randomSplit samples the parent once per split; caching keeps both splits drawn
        // from the same materialized rows instead of re-reading Hive for each one.
        val data = hc.sql("select * from  fin_tec.uvcy2")
          .map(row => LabeledPoint(row.getDouble(1), toFeatures(row)))
          .cache()
        val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

        val numClasses = 2
        val categoricalFeaturesInfo = Map[Int, Int]()
        val numTrees = 3
        val featureSubsetStrategy = "auto"
        val impurity = "gini"
        val maxDepth = 4
        val maxBins = 32
        val model = RandomForest.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo,
          numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)

        val labelAndPreds = testData.map(point => (point.label, model.predict(point.features)))
        val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
        println("Test Error = " + testErr)
        println("Learned classification forest model:\n" + model.toDebugString)
        model.save(sc, "uvcymodel/forest")

        // Score the new rows with the reloaded model. This assumes test.uvcy shares the
        // training layout (features from column 2 on) and keeps each row's ID from column 0.
        val sameModel = RandomForestModel.load(sc, "uvcymodel/forest")
        val idAndPreds = hc.sql("select * from test.uvcy where i_l3_hk_amt=2150").map { row =>
          val features = toFeatures(row)
          (String.valueOf(row.get(0)), sameModel.predict(features), features)
        }
        idAndPreds.take(2).foreach(println)

        sc.stop()
      }
    }
    
  • 相关阅读:
    Cocos2d-js官方完整项目教程翻译:六、添加Chipmunk物理引擎在我们的游戏世界里
    linux coreseek-4.1安装
    8个必备的PHP功能开发
    LINUX 下mysql导出数据、表结构
    PHP缩略图类
    PHP文件上传类
    PHP抓取页面的几种方式
    MySQL性能优化的最佳20+条经验
    Zend Studio 9.0.3 破解及汉化（Windows 版）
    【转载】【面试经验】PHP中级面试题
  • 原文地址:https://www.cnblogs.com/canyangfeixue/p/7762521.html
Copyright © 2011-2022 走看看