zoukankan      html  css  js  c++  java
  • spark 机器学习(ml pipeline)

    1、业务目标,通过训练模型给待处理数据打上标签

    给定训练样本中对包含hello的字符串文本打上标签1,否则打上0.
    期望,通过训练模型用机器学习的方式对待测试数据做同样的操作。

    2、训练样本sample.txt

    三列(id,文本,标签),hello文本标签为1
    0,why hello world JAVA,1.0
    1,what llo java jsp,0.0
    2,test hello2 scala,0.0
    3,abc spark hello,1.0
    4,j hello c#,1.0
    5,i java hell spark,0.0
    6,i java hell spark

    3、待测试数据样本w1.txt

    0,hello world
    1,hello java test num
    2,test hello scala
    3,j hello spark
    4,abc hello c#
    5,hell java spark
    6,hello java spark
    7,num he he java spark
    8,hello2 java spark
    9,hello do some thing java spark
    10,world hello java spark

    4、code

    4.1依赖

            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-core_2.11</artifactId>
                <version>2.4.4</version>
            </dependency>
    
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-sql_2.11</artifactId>
                <version>2.4.4</version>
            </dependency>
    
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-streaming_2.11</artifactId>
                <version>2.4.4</version>
            </dependency>
    
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-mllib_2.11</artifactId>
                <version>2.4.4</version>
            </dependency>

    4.2 实现

    package com.home.spark.ml
    
    import org.apache.spark.SparkConf
    import org.apache.spark.ml.{Pipeline, PipelineModel}
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
    import org.apache.spark.ml.linalg.Vector
    
    /**
      * @Description: 机器学习,训练样本数据,给生产数据打标签
      *              样本训练数据中带有hello的文本,打标签为1,否则为0
      *              通过训练模型,我们希望待测试数据同样用这种方式打上标签。
      **/
    object Ex_label {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(true).setMaster("local[*]").setAppName("spark ml label")
        val spark = SparkSession.builder().config(conf).getOrCreate()
        val error_count = spark.sparkContext.longAccumulator("error_count")
    
        //载入训练数据,数据手工训练,给带有hello的数据打上1.0的标签,给没有hello的数据打上0.0
        val lineRDD: RDD[String] = spark.sparkContext.textFile("input/sample.txt")
    
        //rdd转换成df或者ds需要SparkSession实例的隐式转换
        //导入隐式转换,注意这里的spark不是包名,而是SparkSession的对象名
        import spark.implicits._
    
    
        //生成训练数据,标签数据必须为double
        val training: DataFrame = lineRDD.map(line => {
          val strings: Array[String] = line.split(",")
          if (strings.length == 3) {
            (strings(0), strings(1), strings(2).toDouble)
          }
          else {
            error_count.add(1)
            ("-1", strings.mkString(" "), 0.0)
          }
    
        }).filter(s => !s._1.equals("-1"))
          .toDF("id", "text", "label")
    
        training.printSchema()
        training.show()
    
        println(s"错误数据计数 : ${error_count.value}")
    
    
        //Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
        //Transformer,转换器,字符解析,转换输入文本,以空格分隔,转成小写词
        val tokenizer: Tokenizer = new Tokenizer()
          .setInputCol("text")
          .setOutputCol("words")
    
        //Transformer,转换器,哈希转换,以哈希方式将词转换成词频,转成特征向量
        val hashTF: HashingTF = new HashingTF().setNumFeatures(1000)
          .setInputCol(tokenizer.getOutputCol).setOutputCol("features")
    
        //Estimator,预测器或评估器,逻辑回归,10次最大迭代
        val lr: LogisticRegression = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
    
        //预测器通过 fit() 方法,接收一个 DataFrame 并产出一个模型
        //封装流水线,包含两个转换器(实际包含两个模型),一个评估器(包含一个算法)
        //因为还有评估器,所以需要训练生成最终模型
        val pipeline: Pipeline = new Pipeline().setStages(Array(tokenizer, hashTF, lr))
    
    
        // Fit the pipeline to training documents.
        //训练,生成最终模型
        val model: PipelineModel = pipeline.fit(training)
    
        // 可以选择保存模型到磁盘
        model.write.overwrite().save("/tmp/spark-logistic-regression-model")
        // 重新加载回来
        //    val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model")
    
    
        // 保存未训练(unfit)的流水线到底盘
        //    pipeline.write.overwrite().save("/tmp/unfit-lr-model")
    
        //重新加载流水线
        //    val samePipeline = Pipeline.load("/tmp/unfit-lr-model")
    
    
        //加载待分析数据
        val testRDD: RDD[String] = spark.sparkContext.textFile("input/w1.txt")
        val test: DataFrame = testRDD.map(line => {
          val strings: Array[String] = line.split(",")
          if (strings.length == 2) {
            (strings(0), strings(1))
          }
          else {
            //        error_count.add(1)
            ("-1", strings.mkString(" "))
          }
    
        }).filter(s => !s._1.equals("-1"))
          .toDF("id", "text")
    
        //对给定数据进行预测
        model.transform(test)
          .select("id", "text", "probability", "prediction")
          .collect()
          .foreach {
            case Row(id: String, text: String, prob: Vector, prediction: Double) =>
              println(s"($id, $text) --> prob=$prob, prediction=$prediction")
          }
    
        spark.stop()
    
      }
    }
    
    /* 运行结果
    (0, hello world) --> prob=[0.02467400198786794,0.975325998012132], prediction=1.0
    (1, hello java test num) --> prob=[0.48019580016300345,0.5198041998369967], prediction=1.0
    (2, test hello scala) --> prob=[0.6270035488150222,0.3729964511849778], prediction=0.0 //这条分析错误,样本数据不够,或者样本干扰
    (3, j hello spark) --> prob=[0.031182836719302286,0.9688171632806978], prediction=1.0
    (4, abc hello c#) --> prob=[0.006011466954209337,0.9939885330457907], prediction=1.0
    (5, hell java spark) --> prob=[0.9210765571223096,0.07892344287769032], prediction=0.0
    (6, hello java spark) --> prob=[0.1785326777978406,0.8214673222021593], prediction=1.0
    (7, num he he java spark) --> prob=[0.6923088930430097,0.30769110695699026], prediction=0.0
    (8, hello2 java spark) --> prob=[0.9016001424620457,0.09839985753795444], prediction=0.0
    (9, hello do some thing java spark) --> prob=[0.1785326777978406,0.8214673222021593], prediction=1.0
    (10, world hello java spark) --> prob=[0.05144953292014106,0.9485504670798589], prediction=1.0
    */
    //probability 是预测概率向量,第一个值是不符合度,第二个值是符合度,
    //prediction的标签取决于模型的阀值设置严格度
     
  • 相关阅读:
    jQuery Mobile 总结
    妙味,结构化模块化 整站开发my100du
    详解使用icomoon生成字体图标的方法并应用
    Vue.js搭建路由报错 router.map is not a function,Cannot read property ‘component’ of undefined
    jquery 最全知识点图示
    图解Js event对象offsetX, clientX, pageX, screenX, layerX, x区别
    Oracle存储过程及函数的练习题
    SQL中IS NOT NULL与!=NULL的区别
    mysql字符集和排序规则
    一个web项目web.xml的配置中<context-param>配置作用
  • 原文地址:https://www.cnblogs.com/asker009/p/12145408.html
Copyright © 2011-2022 走看看