  • Spark ML: Implementing Ranking with Logistic Regression (LR)

    I. Theory

    https://www.jianshu.com/p/114100d0517f

    https://www.imooc.com/article/46843
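
    In short, both write-ups cover the same core idea: logistic regression models the positive-class probability as sigmoid(w·x + b) = 1 / (1 + e^-(w·x + b)), and it is this probability, rather than the 0/1 prediction, that the ranking below sorts on.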

    II. Code

    1. Prepare the data

    2. Split the data into train and test sets: fit a model on the train data, then run that model over (transform) the test data

    and check whether the labels and predictions agree closely enough (see the evaluation sketch after this list).

    3. Ranking: sort items by the predicted positive-class probability.
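
    A minimal sketch of the check in step 2, assuming `model` is the fitted LogisticRegressionModel and `test` is the held-out split from the code below (both names match it); BinaryClassificationEvaluator is Spark ML's standard way to compute AUC:

    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

    // Score the held-out split with the trained model
    val predictions = model.transform(test)

    // areaUnderROC near 1.0 means labels and predictions agree well;
    // near 0.5 means the model does no better than random guessing
    val auc = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setRawPredictionCol("rawPrediction")
      .setMetricName("areaUnderROC")
      .evaluate(predictions)
    println(s"test AUC = $auc")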

    package com.njbdqn
    
    import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
    import org.apache.spark.ml.linalg.{DenseVector, Vectors}
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.{Row, SparkSession}
    
    /**
     * Ranking with LR (logistic regression)
     */
    object LRtest {
      // Extract the positive-class probability from the string form of the
      // probability vector, e.g. "[0.05,0.95]" -> 0.95
      val positive = udf { (vc: String) =>
        vc.replaceAll("\\[|\\]", "").split(",")(1).toDouble
      }
    
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("app").master("local[*]").getOrCreate()
        val data = spark.createDataFrame(Seq(
          ("1","2",1.0, Vectors.dense(0.0, 1.1, 0.1)),
          ("1","2",0.0, Vectors.dense(2.0, 1.0, -1.1)),
          ("1","2",1.0, Vectors.dense(1.0, 2.1, 0.1)),
          ("1","2",0.0, Vectors.dense(2.0, -1.3, 1.1)),
          ("1","2",0.0, Vectors.dense(2.0, 1.0, -1.1)),
          ("1","2",1.0, Vectors.dense(1.0, 2.1, 0.1)),
          ("1","2",1.0, Vectors.dense(2.0, 1.3, 1.1)),
          ("1","2",0.0, Vectors.dense(-2.0, 1.0, -1.1)),
          ("1","2",1.0, Vectors.dense(1.0, 2.1, 0.1)),
          ("1","2",0.0, Vectors.dense(2.0, -1.3, 1.1)),
          ("1","2",1.0, Vectors.dense(2.0, 1.0, -1.1)),
          ("1","2",1.0, Vectors.dense(1.0, 2.1, 0.1)),
          ("1","2",0.0, Vectors.dense(-2.0, 1.3, 1.1)),
          ("1","2",1.0, Vectors.dense(0.0, 1.2, -0.4))
        )).toDF("user","goods","label","features")
          //.show(false)
        val Array(train,test) = data.randomSplit(Array(0.7,0.3))
        // Set the model's hyperparameters
        val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
        // Train the model on the train split
        val model = lr.fit(train)
        // Save the model to HDFS (commented out after the first run)
        // model.save("hdfs://192.168.56.111:9000/LRmodel")
        // Load the saved model back from HDFS
        val model2 = LogisticRegressionModel.load("hdfs://192.168.56.111:9000/LRmodel")
        // Check the model's accuracy against the held-out labels
        // val preRes = model.transform(test)
        // preRes.show(false)
        val res = model2.transform(data)
        import spark.implicits._
        // Method 1 (brute force, not recommended): cast the vector to a string
        // probability is [P(label=0), P(label=1)]; the second entry is the
        // degree of interest, and predict is 1 when it exceeds 0.5
        res.withColumn("pro", $"probability".cast("String"))
          .select($"user", $"goods", positive($"pro").alias("score"))
          .orderBy(desc("score")).show(false)
        // Method 2 (recommended): pattern-match the probability vector directly
        val wnd = Window.orderBy(desc("score"))
        res.select("user", "goods", "probability")
          .rdd.map { case Row(uid: String, gid: String, prob: DenseVector) => (uid, gid, prob(1)) }
          .toDF("user", "goods", "score")
          .select($"user", $"goods", $"score", row_number().over(wnd).alias("rank"))
          .show(false)
        spark.stop()
      }
    }

    Result:

    +----+-----+-------------------+
    |user|goods|score              |
    +----+-----+-------------------+
    |1   |2    |0.9473385564891683 |
    |1   |2    |0.9473385564891683 |
    |1   |2    |0.9473385564891683 |
    |1   |2    |0.9473385564891683 |
    |1   |2    |0.9202855138287962 |
    |1   |2    |0.5337766179253915 |
    |1   |2    |0.5337766179253915 |
    |1   |2    |0.5337766179253915 |
    |1   |2    |0.5081492680443979 |
    |1   |2    |0.5014483932183084 |
    |1   |2    |0.4713578993198038 |
    |1   |2    |0.09069927610736443|
    |1   |2    |0.03241657419240436|
    +----+-----+-------------------+
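
    As a side note, on Spark 3.0+ there is a cleaner route than either method above: org.apache.spark.ml.functions.vector_to_array turns the probability vector into a plain array column, so the positive-class score can be pulled out without string parsing or a detour through the RDD API. A sketch, assuming the same `res` DataFrame as in the code above:

    import org.apache.spark.ml.functions.vector_to_array

    // probability -> array, then take element 1 (the positive class)
    res.withColumn("score", vector_to_array(col("probability"))(1))
      .select("user", "goods", "score")
      .orderBy(desc("score"))
      .show(false)

    Note also that in this toy data every row shares the same user, so the window in method 2 simply orders all rows by score; on real data you would normally rank goods within each user, e.g. Window.partitionBy("user").orderBy(desc("score")).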

  • Original post: https://www.cnblogs.com/sabertobih/p/13874338.html