    Chinese text classification with Spark's ml package (Naive Bayes)

    The Chinese word-segmentation workflow and how to obtain the corpus are covered in https://www.cnblogs.com/DismalSnail/p/11801742.html
    This post demonstrates Spark's newer machine-learning package, ml. Segmentation is done with HanLP (see https://github.com/hankcs/HanLP for details), term weights are TF-IDF, and the classifier is naive Bayes. For this experiment, the training and test sets of the Fudan Chinese corpus were merged into a single collection.
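    The code below assumes Spark's ml package, HanLP, and commons-io are on the classpath. A minimal build.sbt sketch (the version numbers are assumptions, not from the original post; adjust them to your environment):

    //build.sbt sketch -- versions are assumptions, adjust to your environment
    libraryDependencies ++= Seq(
      "org.apache.spark" %% "spark-mllib" % "2.4.0",
      "com.hankcs" % "hanlp" % "portable-1.7.8",
      "commons-io" % "commons-io" % "2.6"
    )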


    package com.teligen.subject.ML
    
    import java.io.File
    
    import com.hankcs.hanlp.HanLP
    import org.apache.commons.io.FileUtils
    import org.apache.spark.ml.classification.NaiveBayes
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel, Tokenizer}
    import org.apache.spark.sql.{DataFrame, SparkSession}
    
    import scala.collection.mutable.ListBuffer
    
    /**
     * Naive Bayes training example
     */
    object NBClassDemo {
    
      //Turn the segmented term string into space-separated words,
      //e.g. "[word1, word2, word3]" --> "word1 word2 word3"
      def toStringList(termString: String): String = {
        termString.replace("[", "").replace("]", "").replace(",", "")
      }
    
      //Holds the Double label and the segmented string for each document.
      //Note: the label Doubles must start at 0.0 and grow in order (0.0, 1.0, 2.0, ...);
      //otherwise, even when a prediction is correct, it will not match the label's
      //Double value and the accuracy computation will be wrong
      val labelAndSentenceSeq: ListBuffer[(Double, String)] = ListBuffer[(Double, String)]()
    
      //Segmentation
      def segment(corpusPath: String): Unit = {
        //Double value used as the class label, starting from 0.0
        var count: Double = 0.0
        //Tell HanLP not to attach part-of-speech tags, so toString carries no
        //POS characters; this makes building the term vectors easier
        HanLP.Config.ShowTermNature = false
        //Open the corpus root directory
        val corpusDir: File = new File(corpusPath)
        //Iterate over the per-class directories
        for (classDir: File <- corpusDir.listFiles()) {
          //Iterate over the documents in each class
          for (text <- classDir.listFiles()) {
            //Store the label Double and the segmented string in labelAndSentenceSeq
            labelAndSentenceSeq.append(Tuple2(count,
              //Post-process HanLP.segment().toString so terms are space-separated
              toStringList(
                //Segment the text
                HanLP.segment(
                  //Read the whole file in as a single string
                  FileUtils.readFileToString(text)
                    .replace("\r\n", "") //strip Windows line endings
                    .replace("\r", "") //strip carriage returns
                    .replace("\n", "") //strip line feeds
                    .replace(" ", "") //strip spaces
                    .replace("\u3000", "") //strip full-width (Chinese) spaces
                    .replace("\t", "") //strip tabs
                    .replaceAll("\\pP|\\pS|\\pC|\\pN|\\pZ", "") //strip punctuation, symbols, control, number and separator characters via Unicode category regexes
                    .trim
                  //the segmenter's toString joins words with ", ", hence toStringList above
                ).toString)))
          }
          //Move on to the next class label
          count = count + 1.0
        }
      }
    
      //Build TF-IDF-weighted term vectors
      def tfIdf(spark: SparkSession): DataFrame = {
        //Turn the (label Double, segmented string) pairs into a DataFrame
        val sentenceData: DataFrame = spark.createDataFrame(labelAndSentenceSeq.toSeq).toDF("label", "sentence")

        //Split each segmented string into individual words. Tokenizer() can only split
        //strings on whitespace; RegexTokenizer is more flexible -- see the Tokenizer()
        //source for details (a commented sketch follows)
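        //For reference, a hypothetical RegexTokenizer near-equivalent (an assumption,
        //not used anywhere below; it lives in org.apache.spark.ml.feature as well):
        //  val regexTokenizer = new RegexTokenizer()
        //    .setInputCol("sentence").setOutputCol("words")
        //    .setPattern("\\s+") //split on runs of whitespace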
        
        //Build the sentence --> words tokenizer
        val tokenizer: Tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
        //Run the split.
        //Without the select() calls, every intermediate column is kept in the DataFrame,
        //which makes it large and easily triggers java heap space errors
        val wordsData: DataFrame = tokenizer.transform(sentenceData).select("label", "words")
        
        //Build the words --> rawFeatures HashingTF transformer
        val hashingTF: HashingTF = new HashingTF()
          .setInputCol("words").setOutputCol("rawFeatures")
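        //Note: HashingTF hashes words into a fixed number of buckets (2^18 by default);
        //for a very large vocabulary, hash collisions can be reduced with e.g.
        //  .setNumFeatures(1 << 20)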
        
        //Compute each document's term frequencies, i.e. TF (Term Frequency)
        val featurizedData: DataFrame = hashingTF.transform(wordsData).select("label", "rawFeatures")

        //Build the rawFeatures --> features IDF estimator
        val idf: IDF = new IDF().setInputCol("rawFeatures").setOutputCol("features")
        //Fit the IDF (Inverse Document Frequency) model
        val idfModel: IDFModel = idf.fit(featurizedData)
        //Compute the TF-IDF features
        idfModel.transform(featurizedData).select("label", "features")
      }
    
      //Train and predict
      def trainAndPredict(tfIdfData: DataFrame): Unit = {
        //Split into training and test sets by ratio
        val Array(trainingData, testData) = tfIdfData.randomSplit(Array(0.7, 0.3), seed = 1234L)
        //Train the naive Bayes classifier
        val model = new NaiveBayes().fit(trainingData)
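        //Note: spark.ml's NaiveBayes defaults to the multinomial model with smoothing 1.0;
        //both can be changed via .setModelType("multinomial" | "bernoulli") and .setSmoothing(...)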
        //Predict on the test set
        val predictions = model.transform(testData)
        //Show the first 50 predictions
        predictions.show(50)
    
        //Evaluate the predictions
        val evaluator = new MulticlassClassificationEvaluator()
          .setLabelCol("label")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
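        //Other supported metric names include "f1", "weightedPrecision" and "weightedRecall"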
        //Test-set accuracy
        val accuracy = evaluator.evaluate(predictions)
        println(s"Test set accuracy = $accuracy")
      }
    
      def main(args: Array[String]): Unit = {
        //Create the Spark session
        val spark = SparkSession.builder().master("local[2]").appName("NBC").getOrCreate()
        //Segment the corpus
        segment("./corpus/all_corpus/")
        //Build TF-IDF features, then train and predict
        trainAndPredict(tfIdf(spark))
      }
    }
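
    As a usage note, the same Tokenizer --> HashingTF --> IDF --> NaiveBayes chain can also be wrapped in a spark.ml Pipeline, so the fitted stages travel together and can be saved or reused as a single model. Here is a minimal sketch under the same column names; fitSentencePipeline is a hypothetical helper, not part of the code above:

    import org.apache.spark.ml.{Pipeline, PipelineModel}
    import org.apache.spark.ml.classification.NaiveBayes
    import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
    import org.apache.spark.sql.DataFrame

    //Hypothetical refactor: one Pipeline over a (label, sentence) DataFrame,
    //such as the sentenceData built inside tfIdf() above
    def fitSentencePipeline(sentenceData: DataFrame): PipelineModel = {
      val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
      val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures")
      val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
      val nb = new NaiveBayes() //reads the "features" and "label" columns by default

      new Pipeline()
        .setStages(Array(tokenizer, hashingTF, idf, nb))
        .fit(sentenceData)
    }

    The fitted PipelineModel can then transform() a raw (label, sentence) DataFrame directly, and can be persisted with its write.save(...) method.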
    
    
