zoukankan      html  css  js  c++  java
  • 推荐系统-01-简单逻辑回归

    import org.apache.spark.ml.feature._
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.{Pipeline,PipelineModel}
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.SparkSession
    
    
    /** Minimal Spark ML logistic-regression walkthrough.
      *
      * Fits a `LogisticRegression` estimator twice on a tiny in-memory data set
      * (once with the parameters set on the estimator, once with a `ParamMap`
      * passed to `fit`), prints the parameters each model was fit with, and
      * prints the predictions of the first model on a small test set.
      */
    object BasicStastic {

      /** Program entry point.
        *
        * @param args command-line arguments; not read by this program
        */
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local")
          .appName("my App Name")
          .getOrCreate()

        // Stop the session even when fitting or scoring throws, so the local
        // Spark context is not leaked.
        try {
          // Training data as (label, features) rows.
          val training = spark.createDataFrame(Seq(
            (0.0, Vectors.dense(2.0, 1.1, 0.1)),
            (1.0, Vectors.dense(0.0, 1.0, -1.0)),
            (2.0, Vectors.dense(0.0, 1.3, 1.0)),
            (3.0, Vectors.dense(2.0, 1.2, -0.5))
          )).toDF("label", "features")

          // The available parameters are described in the official API docs:
          // http://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.ml.classification.LogisticRegression
          val lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(0.01)

          // First model: uses the parameters currently stored on `lr`.
          val model1 = lr.fit(training)
          println(s"Model 1 was fit using parameters: ${model1.parent.extractParamMap()}")

          // Second model: parameters are supplied through a ParamMap handed to fit().
          // NOTE(review): `threshold` is documented as a binary-classification
          // parameter while the training labels take four distinct values —
          // confirm it has the intended effect here.
          val paramMap = ParamMap(lr.maxIter -> 20)
            .put(lr.regParam -> 0.1, lr.threshold -> 0.55)
          val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability")
          val paramMapCombined = paramMap ++ paramMap2

          val model2 = lr.fit(training, paramMapCombined)
          println(s"Model 2 was fit using parameters: ${model2.parent.extractParamMap()}")

          // Test data as (label, features) rows.
          val test = spark.createDataFrame(Seq(
            (3.0, Vectors.dense(-1.0, 1.5, 1.3)),
            (0.0, Vectors.dense(3.0, 2.0, -0.1)),
            (1.0, Vectors.dense(0.0, 2.2, -1.5))
          )).toDF("label", "features")

          // NOTE(review): predictions come from model1; model2 is only fit and
          // reported above. Scoring with model2 would require selecting
          // "myProbability" instead of "probability".
          val result = model1.transform(test)

          // Show the full result, then only the columns of interest.
          result.show(false)
          val selected = result.select("label", "features", "probability", "prediction")
          selected.show(false)
          selected.collect().foreach {
            case Row(label: Double, features: Vector, probability: Vector, prediction: Double) =>
              println(s"($features, $label) ->  probability=$probability, prediction=$prediction")
          }
        } finally {
          spark.stop()
        }
      }
    }
    
    
  • 相关阅读:
    Jetty和tomcat的比较
    Spring Boot – Jetty配置
    Java规则之条件语句中做空判断时使用||和&&常犯的错误
    bboss foreach循环嵌套遍历map
    url全部信息打印
    ajax省市县三级联动
    关于mysql中的count()函数
    vue——统一配置axios的baseUrl和所有请求的路径
    js——substr与substring的区别
    vue——axios请求成功却进入catch的原因
  • 原文地址:https://www.cnblogs.com/freebird92/p/9026311.html
Copyright © 2011-2022 走看看