zoukankan      html  css  js  c++  java
  • Spark实现tf-idf

    scala代码:

    package offline
    
    import org.apache.spark.ml.feature.{HashingTF, IDF}
    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._
    
    object TfIdfTransform {
      /**
       * Computes TF-IDF features for the `badou.news_seg` Hive table twice:
       *   1. with Spark ML's built-in HashingTF + IDF pipeline, and
       *   2. with a hand-rolled dictionary-based implementation using UDFs.
       *
       * Each row's `sentence` column holds pre-segmented text of the form
       * "word1 word2 ...##@@##label"; only the part before the `##@@##`
       * separator is the document text used for TF-IDF.
       */
      def main(args: Array[String]): Unit = {
        val spark = SparkSession
          .builder()
          .appName("tf-idf")
          .enableHiveSupport()
          .getOrCreate()

        // All documents were previously merged into one table (originally allfiles.txt).
        val df = spark.sql("select sentence from badou.news_seg")
        // Strip the "##@@##label" suffix, then tokenize on spaces.
        val df_seg = df.selectExpr("split(split(sentence,'##@@##')[0],' ') as seg")
        // Total number of documents N, used by the hand-rolled IDF below.
        val doc_size = df_seg.count()

        // ---- Spark built-in TF-IDF ----
        // HashingTF maps [word1, word2, ...] => SparseVector(2^18,
        // [hash(word) as index] -> word count).
        val hashingtf = new HashingTF().setBinary(false)
          .setInputCol("seg").setOutputCol("feature_tf")
          .setNumFeatures(1 << 18)  // 2^18 = 262,144 hash buckets
        val df_tf = hashingtf.transform(df_seg).select("feature_tf")

        // IDF re-weights the term counts; terms appearing in fewer than
        // 2 documents are ignored (their weight is forced to 0).
        val idf = new IDF().setInputCol("feature_tf").setOutputCol("feature_tfidf")
          .setMinDocFreq(2)
        val idfModel = idf.fit(df_tf)
        val df_tfIdf = idfModel.transform(df_tf).select("feature_tfidf")

        // ---- Hand-rolled TF-IDF ----
        // 1. Document frequency: count, for each word, how many documents
        //    contain it; the key set of docFreq_map is the dictionary.
        //    FIX: split off the "##@@##label" suffix BEFORE tokenizing, so the
        //    dictionary/DF is built from the same token stream the TF step
        //    below uses. The original tokenized the whole `sentence` (label
        //    included), making DF inconsistent with TF and with the built-in
        //    pipeline above.
        val setUDF = udf((str: String) => str.split("##@@##")(0).split(" ").distinct)
        val df_set = df.withColumn("words_set", setUDF(col("sentence")))
        val docFreq_map = df_set.select(explode(col("words_set")).as("word"))
          .groupBy("word").count().rdd.map(x => (x(0).toString, x(1).toString))
          .collectAsMap()
        // Assign each dictionary word a stable integer index in [0, dictSize).
        // (The original corpus had about 42,363 distinct words.)
        val wordEncode = docFreq_map.keys.zipWithIndex.toMap
        val dictSize = docFreq_map.size

        // 2. Term frequency per document (one row = one document), combined
        //    with the IDF weight into a sparse vector over the dictionary.
        val mapUDF = udf { (str: String) =>
          // word -> occurrence count within this document
          val tfMap = str.split("##@@##")(0).split(" ")
            .map((_, 1L)).groupBy(_._1).mapValues(_.length)

          val tfIDFMap = tfMap.map { x =>
            // Smoothed IDF: log10(N / (df + 1)); the +1 guards against
            // division by zero for words missing from the dictionary.
            val idf_v = math.log10(doc_size.toDouble / (docFreq_map.getOrElse(x._1, "0.0").toDouble + 1.0))
            // Unknown words fall back to index 0; with the DF fix above every
            // document word is in the dictionary, so this is a safety net only.
            (wordEncode.getOrElse(x._1, 0), x._2.toDouble * idf_v)
          }

          Vectors.sparse(dictSize, tfIDFMap.toSeq)
        }
        val dfTF = df.withColumn("tf_idf", mapUDF(col("sentence")))
      }

    }
  • 相关阅读:
    Python 3 Mysql 增删改查
    Python3 MySQL 数据库连接 -PyMySQL
    java 获取cookie
    Python 通过配置文件 读取参数,执行测试用例,生成测试报告并发送邮件
    Python 操作 Excel 、txt等文件
    SonarQube代码质量管理平台安装与使用
    Python + HTMLTestRunner + smtplib 完成测试报告生成及发送测试报告邮件
    Python 解析Xml文件
    GO语言基础
    FileBeat
  • 原文地址:https://www.cnblogs.com/xumaomao/p/12763375.html
Copyright © 2011-2022 走看看