zoukankan      html  css  js  c++  java
  • SparkMllib分类问题的模板代码

    • 需求:对数据进行分类问题的处理

    • 开发步骤:

      • 1-准备SparkSession的环境
      • 2-准备大数据的数据
      • 3-读取数据并进行解析
      • 4-数据的基本信息的查看
      • 5-特征工程
      • 6-准备算法
      • 7-模型训练
      • 8-模型预测
      • 9-模型校验
      • 10-模型保存
      • 11-新数据预测
    • 代码模板:

    import org.apache.spark.SparkConf
    import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature._
    import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
    
    /**
      * DESC: 对分类问题的模板的代码
      * Complete data processing and modeling process steps:
      *- 1-准备SparkSession的环境
      *- 2-准备大数据的数据
      *- 3-读取数据并进行解析
      *- 4-数据的基本信息的查看
      *- 5-特征工程
      *- 6-准备算法
      *- 7-模型训练
      *- 8-模型预测
      *- 9-模型校验
      *- 10-模型保存
      *- 11-新数据预测
      *
      */
    object ClassficationModelTest {
    
      var datapath = "D:\BigData\Workspace\SparkMachineLearningTest\SparkMllib_BigData32\src\main\resources\iris.csv"
    
      def main(args: Array[String]): Unit = {
        //    - 1-准备SparkSession的环境
        val conf: SparkConf = new SparkConf().setAppName("ClassficationModelTest").setMaster("local[*]")
        val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
        //    - 2-准备大数据的数据
        val irisDF: DataFrame = spark.read.format("csv")
          .option("header", true)
          .option("inferschema", true)
          .option("sep", ",")
          .load(datapath)
        //    - 3-读取数据并进行解析
        irisDF.show(10, false)
        //    +------------+-----------+------------+-----------+-----------+
        //    |sepal_length|sepal_width|petal_length|petal_width|class      |
        //    +------------+-----------+------------+-----------+-----------+
        //    |5.1         |3.5        |1.4         |0.2        |Iris-setosa|
        //      |4.9         |3.0        |1.4         |0.2        |Iris-setosa|
        //      |4.7         |3.2        |1.3         |0.2        |Iris-setosa|
        //      |4.6         |3.1        |1.5         |0.2        |Iris-setosa|
        //    - 4-数据的基本信息的查看
        irisDF.printSchema()
        // 因为在写各种string类型数据的时候可能会有一些单词拼写错误,可以实现定义
        val sepal_length_feeature = "sepal_length"
        val sepal_width_feeature = "sepal_width"
        val petal_length_feeature = "petal_length"
        val petal_width_feeature = "petal_width"
        val class_label = "class"
        //    root
        //    |-- sepal_length: double (nullable = true)
        //    |-- sepal_ double (nullable = true)
        //    |-- petal_length: double (nullable = true)
        //    |-- petal_ double (nullable = true)
        //    |-- class: string (nullable = true)
        //    - 5-特征工程
        //5-1处理类别型的数据class
        val stringIndexer: StringIndexer = new StringIndexer()
          .setInputCol(class_label)
          .setOutputCol("classlabel")
        val stringIndexerModel: StringIndexerModel = stringIndexer.fit(irisDF)
        val indexDF: DataFrame = stringIndexerModel.transform(irisDF)
        //5-2处理分散的特征整合为特征向量
        val vectorAssembler: VectorAssembler = new VectorAssembler()
          .setInputCols(Array(sepal_length_feeature, sepal_width_feeature, petal_length_feeature, petal_width_feeature))
          .setOutputCol("features")
        val vecDF: DataFrame = vectorAssembler.transform(indexDF)
        //5-3VectorIndexer对类别值的索引化,加速构建决策树
        val vectorIndexer: VectorIndexer = new VectorIndexer()
          .setInputCol("features")
          .setOutputCol("vecindexFeatures")
          .setMaxCategories(20)
        val vectorIndexerModel: VectorIndexerModel = vectorIndexer.fit(vecDF)
        val vecindexerDF: DataFrame = vectorIndexerModel.transform(vecDF)
        vecindexerDF.show(10, false)
        //    - 6-准备算法
        val classifier: DecisionTreeClassifier = new DecisionTreeClassifier()
          .setLabelCol("classlabel")
          .setPredictionCol("prces")
          .setFeaturesCol("vecindexFeatures")
          .setMaxDepth(5)
          .setImpurity("gini")
        val Array(trainingSet, testSet): Array[Dataset[Row]] = vecindexerDF.randomSplit(Array(0.8, 0.2), seed = 1234L)
        //    - 7-模型训练
        val model: DecisionTreeClassificationModel = classifier.fit(trainingSet)
        //    - 8-模型预测
        val y_pred_train: DataFrame = model.transform(trainingSet)
        val y_pred_test: DataFrame = model.transform(testSet)
        y_pred_train.show(10, false)
        //    - 9-模型校验
        val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
          //"(f1|weightedPrecision|weightedRecall|accuracy)"
          .setMetricName("accuracy")
          .setPredictionCol("prces")
          .setLabelCol("classlabel")
        val acc_test: Double = evaluator.evaluate(y_pred_test)
        val acc_train: Double = evaluator.evaluate(y_pred_train)
        println("acc in trainset score is:", acc_train)
        println("acc in testset score is:", acc_test)
        //    (acc in trainset score is:,0.9920634920634921)
        //    (acc in testset score is:,0.9583333333333334)
        //    //    - 10-模型保存
        //    val datapath="D:\BigData\Workspace\SparkMachineLearningTest\SparkMllib_BigData32\src\main\resources\model1"
        //    model.save(datapath)
        //    //    - 11-新数据预测
        //    DecisionTreeClassificationModel.load(datapath)
    
      }
    }
    
  • 相关阅读:
    C++ Primer Plus 第15章 友元、异常和其它
    03013_JDBC工具类
    python GUI编程(Tkinter)
    Python2.x与3​​.x版本区别
    【python教程】Python JSON
    【python教程】Python IDE
    通过Google Custom Search API 进行站内搜索
    支持wmv、mpg、mov、avi格式的网页视频播放代码
    编写更好的jQuery代码的建议
    KindEditor得不到textarea值的解决方法
  • 原文地址:https://www.cnblogs.com/haojia/p/12396975.html
Copyright © 2011-2022 走看看