zoukankan      html  css  js  c++  java
  • Spark ML逻辑回归

     1 import org.apache.log4j.{Level, Logger}
     2 import org.apache.spark.ml.classification.LogisticRegression
     3 import org.apache.spark.ml.linalg.Vectors
     4 import org.apache.spark.sql.SparkSession
     5 
     /**
       * Logistic regression example built on Spark ML.
       *
       * Builds a small in-memory DataFrame of (label, features) rows, randomly
       * splits it into a training and a test part, fits a logistic regression
       * model on the training part and prints the predictions for the test part.
       *
       * Created by zhen on 2018/11/20.
       */
     object LogisticRegression {
       // Raise the log level for everything under "org" so the output is not
       // buried under INFO messages.
       Logger.getLogger("org").setLevel(Level.WARN)

       /**
         * Program entry point.
         *
         * @param args command-line arguments (not used)
         */
       def main(args: Array[String]): Unit = {
         // The enclosing object has the same simple name as the Spark estimator,
         // so the estimator is imported under an alias to keep the code readable.
         import org.apache.spark.ml.classification.{LogisticRegression => SparkLogisticRegression}

         val spark = SparkSession.builder()
           .appName("LogisticRegression")
           .master("local[2]")
           .getOrCreate()

         try {
           // Sample data: each row is (label, feature vector).
           val data = spark.createDataFrame(Seq(
             (1.0, Vectors.dense(0.0, 1.1, 0.1)),
             (0.0, Vectors.dense(2.0, 1.0, -1.1)),
             (1.0, Vectors.dense(1.0, 2.1, 0.1)),
             (0.0, Vectors.dense(2.0, -1.3, 1.1)),
             (0.0, Vectors.dense(2.0, 1.0, -1.1)),
             (1.0, Vectors.dense(1.0, 2.1, 0.1)),
             (1.0, Vectors.dense(2.0, 1.3, 1.1)),
             (0.0, Vectors.dense(-2.0, 1.0, -1.1)),
             (1.0, Vectors.dense(1.0, 2.1, 0.1)),
             (0.0, Vectors.dense(2.0, -1.3, 1.1)),
             (1.0, Vectors.dense(2.0, 1.0, -1.1)),
             (1.0, Vectors.dense(1.0, 2.1, 0.1)),
             (0.0, Vectors.dense(-2.0, 1.3, 1.1)),
             (1.0, Vectors.dense(0.0, 1.2, -0.4))
           )).toDF("label", "features")

           // Split weights: 0.8 for training, 0.2 for testing. No seed is passed,
           // so the split (and therefore the printed output) differs between runs.
           val Array(training, test) = data.randomSplit(Array(0.8, 0.2))

           // Configure the estimator: 10 iterations at most, regularization 0.01.
           val lr = new SparkLogisticRegression()
             .setMaxIter(10)
             .setRegParam(0.01)

           // Train on the training split, then print one row per test example.
           val model = lr.fit(training)
           model.transform(test)
             .select("label", "features", "probability", "prediction")
             .collect()
             .foreach(println)
         } finally {
           // Always shut the session down, even when training or scoring fails.
           spark.stop()
         }
       }
     }

    结果:

  • 相关阅读:
    常见的HTTP状态码(HTTP Status Code)说明
    Java基本数据类型和Integer缓存机制
    面向对象的三大基本特征和五大基本原则
    工程变更(ENGINEERING CHANGE)
    反射
    Redis学习手册(开篇)
    MVC,SSM与三层架构的构成及相互关系
    Java框架篇---Mybatis 入门
    java三大框架介绍
    WEB前端JS与UI框架
  • 原文地址:https://www.cnblogs.com/yszd/p/9988597.html
Copyright © 2011-2022 走看看