zoukankan      html  css  js  c++  java
  • Spark SVM分类器

    package Spark_MLlib
    
    import java.util.Properties
    
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.mllib.classification.SVMWithSGD
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.Row
    /**
    * SGD:一种优化方法,目的使目标函数的值最小.SGDstochastic gradient descent,即随机梯度下降.是梯度下降的batch版本.
    * 这么做的好处在于:
    * 当训练数据太多时,利用整个数据集更新往往时间上不显示.batch的方法可以减少机器的压力,并且可以更快地收敛.
    * 当训练集有很多冗余时(类似的样本出现多次),batch方法收敛更快.以一个极端情况为例,若训练集前一半和后一半梯度相同.那么如果前一半作为一个batch,
    * 后一半作为另一个batch,那么在一次遍历训练集时,batch的方法向最优解前进两个step,而整体的方法只前进一个step.
    */ object SVM{ def main(args: Array[String]): Unit = { val spark=SparkSession.builder().master("local").getOrCreate() val train_data=spark.sparkContext.textFile("file:///home/soyo/下载/Hadoop+Spark+Hbase/Spark训练数据/train_after.csv") val test_data=spark.sparkContext.textFile("file:///home/soyo/下载/Hadoop+Spark+Hbase/Spark训练数据/test_after.csv") val train=train_data.map{x=> val parts=x.split(",") LabeledPoint(parts(4).toDouble,Vectors.dense(parts(0).toDouble,parts(1).toDouble,parts(2).toDouble,parts(3).toDouble)) } // train.foreach(println) val test=test_data.map{x=> val y=x.split(",") LabeledPoint(y(4).toDouble,Vectors.dense(y(0).toDouble,y(1).toDouble,y(2).toDouble,y(3).toDouble)) } //numIterations:迭代次数,默认是100 val numIterations = 600 val model=SVMWithSGD.train(train,numIterations) //val model=new SVMWithSGD().run(train) //两种求model都行 //清除默认阈值,这样会输出原始的预测评分,即带有确信度的结果。 model.clearThreshold() val scoreAndLabels=test.map{x=> val score=model.predict(x.features) score+" "+x.label } scoreAndLabels.foreach(println) val rebuyRDD=scoreAndLabels.map(_.split(" ")) //设置模式信息 val schema=StructType(List(StructField("score",StringType,true),StructField("label",StringType,true))) //创建Row对象,每个Row对象都是rowRDD中的一行 val rowRDD=rebuyRDD.map(x=>Row(x(0).trim,x(1).trim)) //建立模式和数据之间的关系 val rebuyDF=spark.createDataFrame(rowRDD,schema) //prop变量保存JDBC连接参数 val prop=new Properties() prop.put("user","root") prop.put("password","密码") prop.put("driver","com.mysql.jdbc.Driver") //表示驱动程序是com.mysql.jdbc.Driver rebuyDF.write.mode("append").jdbc("jdbc:mysql://localhost:3306/数据库名","数据库表名",prop) } }

    Spark 机器学习库从 1.2 版本以后被分为两个包:

    
    
    • spark.mllib 包含基于RDD的原始算法API。Spark MLlib 历史比较长,在1.0 以前的版本即已经包含了,提供的算法实现都是基于原始的RDD。
    • spark.ml 则提供了基于DataFrames 高层次的API,可以用来构建机器学习工作流(PipeLine).ML Pipeline 弥补了原始 MLlib 库的不足,向用户提供了一个基于 DataFrame 的机器学习工作流式 API 套件。
    从Spark2.0开始,Spark机器学习库基于RDD的API进入维护模式(即不增加任何新的特性),很有可能于3.0以后的版本的时候会移除出MLLib

  • 相关阅读:
    jquery.autocomplete 使用解析
    《SEO实战密码》
    Thinkphp 生成的验证码不显示问题解决
    css去除li的小圆点
    css隐藏input边框阴影
    HBuilde 申请密钥证书
    请求筛选模块被配置为拒绝包含 hiddenSegment 节的 URL 中的路径
    js 判断屏幕下拉上滑操作
    gis 从WGS84转百度
    GIS个坐标系转换
  • 原文地址:https://www.cnblogs.com/soyo/p/8011276.html
Copyright © 2011-2022 走看看