zoukankan      html  css  js  c++  java
  • 掌握Spark机器学习库-09.6-LDA算法

    数据集

    iris.data

    数据集概览

    代码

    package org.apache.spark.examples.examplesforml
    
    import org.apache.spark.ml.clustering.{KMeans, LDA}
    import org.apache.spark.SparkConf
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.sql.SparkSession
    
    import scala.util.Random
    
    /**
     * Example driver that fits a Spark ML LDA model on the iris data set.
     *
     * The program reads a header-less CSV file whose first four columns are numeric
     * measurements and whose fifth column is the iris class name, assembles the four
     * measurements into a feature vector, trains a 3-topic LDA model on an 80% split
     * and prints the per-row topic distributions, the topic descriptions, and the
     * model's log-likelihood / log-perplexity on the training split.
     */
    object lLDA {

      /** Input used when no path is passed on the command line.
       *  Triple-quoted so the Windows backslashes are taken literally: in a regular
       *  string literal `\9` and `\i` are invalid escape sequences and do not compile.
       */
      private val DefaultInputPath = """D:\9-4LDA算法\iris.data"""

      /** Maps an iris class name to its numeric label.
       *
       *  @param name class name read from the fifth CSV column (may be null for a malformed row)
       *  @return 0, 1 or 2 for the three known iris classes
       *  @throws IllegalArgumentException if the name is not one of the known classes
       */
      private def labelIndex(name: String): Int = name match {
        case "Iris-setosa" => 0
        case "Iris-versicolor" => 1
        case "Iris-virginica" => 2
        // Fail with a readable message instead of an opaque scala.MatchError.
        case other => throw new IllegalArgumentException(s"Unexpected iris class label: $other")
      }

      /**
       * Program entry point.
       *
       * @param args optional; `args(0)` overrides the default input CSV path
       */
      def main(args: Array[String]): Unit = {
        val inputPath = args.headOption.getOrElse(DefaultInputPath)

        val conf = new SparkConf().setMaster("local").setAppName("iris")
        val spark = SparkSession.builder().config(conf).getOrCreate()

        try {
          // No header option is set, so columns are read as strings named _c0.._c4.
          val file = spark.read.format("csv").load(inputPath)
          file.show()

          import spark.implicits._
          val random = new Random()
          val data = file.map(row => {
            val label = labelIndex(row.getString(4))

            (row.getString(0).toDouble,
              row.getString(1).toDouble,
              row.getString(2).toDouble,
              row.getString(3).toDouble,
              label,
              random.nextDouble())
          }).toDF("_c0", "_c1", "_c2", "_c3", "label", "rand").sort("rand") // shuffle rows via the random key

          val assembler = new VectorAssembler()
            .setInputCols(Array("_c0", "_c1", "_c2", "_c3"))
            .setOutputCol("features")

          val dataset = assembler.transform(data)
          val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))
          train.show()
          /*
              val kmeans = new KMeans().setFeaturesCol("features").setK(3).setMaxIter(20)
              val model = kmeans.fit(train)
              model.transform(train).show()
              */
          val lda = new LDA().setFeaturesCol("features").setK(3).setMaxIter(40)
          val model = lda.fit(train)
          val prediction = model.transform(train)
          //prediction.show()
          val ll = model.logLikelihood(train)
          val lp = model.logPerplexity(train)
          // Describe topics.
          val topics = model.describeTopics(3)
          prediction.select("label", "topicDistribution").show(false)
          println("The topics described by their top-weighted terms:")
          topics.show(false)
          println(s"The lower bound on the log likelihood of the entire corpus: $ll")
          println(s"The upper bound on perplexity: $lp")
        } finally {
          // Release the local Spark context even if reading or training fails.
          spark.stop()
        }
      }
    }

    输出结果

  • 相关阅读:
    Laravel自定义分页样式
    mysql中 key 、primary key 、unique key 和 index 有什么不同
    PHP RSA公私钥的理解和示例说明
    PHP操作Excel – PHPExcel 基本用法
    Yii 1.1 常规框架部署和配置
    阿里云服务器 Ubuntu 安装 LNMP
    全国地区sql表
    十道海量数据处理面试题与十个方法大总结
    Hibernate中对象的三种状态以及Session类中saveOrUpdate方法与merge方法的区别
    乐观锁与悲观锁——解决并发问题
  • 原文地址:https://www.cnblogs.com/moonlightml/p/9789846.html
Copyright © 2011-2022 走看看