zoukankan      html  css  js  c++  java
  • spark 线性回归算法(scala)

    构建Maven项目,托管jar包

    数据格式

    //0.fp_nid,1.nsr_id,2.gf_id,3.hydm,4.djzclx_dm,5.kydjrq,6.xgrq,7.je,8.se,9.jshj,10.kpyf,11.kprq,12.zfbz,13.date_key,14.hwmc,15.ggxh,16.dw,17.sl,18.dj,19.je je1,20.se1,21.spbm,22.label

    (fpid_10000201 115717 (2239 173 2011-07-12 00:00:00.0 2016-08-31 15:40:37.0 4123.08 700.92 4824.0 201704 2017-04-25 N) 201706 可视回油单向阀 HYS-1Φ1.5A 只 3.0 35.8974358974359 107.69 18.31 1090120040000000000) 0)
    (fpid_10000324 253389 (7310 173 2016-01-04 00:00:00.0 2017-07-24 10:01:02.0 36609.76 6223.64 42833.4 201709 2017-09-08 N) 201711 电视机 三星743寸 台 1.0 2991.4529914529912 2991.45 508.55 1090522010000000000) 0)
    (fpid_10000416 126378 (5175 173 1999-01-14 00:00:00.0 2016-05-27 14:50:49.0 25337.81 4307.39 29645.2 201612 2016-12-21 N) 201706 防水涂料 null 公斤 105.0 5.225885225885226 548.72 93.28 1070101060000000000) 0)

    package Test.tett1
    
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.ml.regression.LinearRegressionModel
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.regression.LinearRegression
    
    /**
      * Linear-regression demo: reads delimited text rows from HDFS, fits a
      * Spark ML `LinearRegression` model on seven numeric features, prints the
      * predictions and writes them to a MySQL table through JDBC.
      */
    object MLDemo3 {

      /**
        * One parsed input row. Fields are taken positionally from the split line
        * (indices 0 to 13 in declaration order); everything is kept as raw text
        * and converted to numbers only when the feature vector is built.
        *
        * Declared at object level rather than inside `main` so that the class
        * does not depend on method-local context when used inside RDD closures.
        */
      private final case class FP(fp_nid: String, nsr_id: String, gf_id: String, hydm: String,
                                  djzclx_dm: String, kydjrq: String, xgrq: String,
                                  je: String, se: String, jshj: String, kpyf: String,
                                  kprq: String, zfbz: String, label: String)

      /** Plain decimal number: optional sign, digits, optional fractional part. */
      private val DecimalPattern = """-?[0-9]+(?:\.[0-9]+)?""".r

      /**
        * Converts a textual amount to a `Double`.
        *
        * The previous pattern accepted digits only, so every value with a decimal
        * point (for example "4123.08") fell through to the fallback.
        *
        * @param str raw field text
        * @return the parsed value, or 1.0 (the original fallback) when `str` is
        *         not a plain decimal number
        */
      private def strToDouble(str: String): Double = str match {
        case DecimalPattern() => str.toDouble
        case _                => 1.0
      }

      /**
        * Runs the whole pipeline: load, train, predict, export.
        *
        * @param args command-line arguments (unused)
        */
      def main(args: Array[String]): Unit = {
        val sess = SparkSession.builder().appName("ml").master("local[4]").getOrCreate()
        val sc = sess.sparkContext
        val dataDir = "hdfs://weekend110:9000/user/hive/warehouse/nsr2_xfp"

        // NOTE(review): the original split on the literal text "01", which looks
        // like a mangled "\001". The path is a Hive warehouse directory, so Hive's
        // default field delimiter Ctrl-A (U+0001) is assumed here; confirm against
        // the table definition.
        val fieldDelimiter = "\u0001"

        // Original author's column list, kept verbatim (index 2 appears twice in it,
        // so its numbering is unreliable):
        // 0.fp_nid,1.nsr_id,2.gf_id,2.hydm,3.djzclx_dm,4.kydjrq,5.xgrq,6.je,7.se,8.jshj,9.kpyf,10.kprq,11.zfbz,12.date_key,13.hwmc,14.ggxh,15.dw,16.sl,17.dj,18.je je1,19.se1,20.spbm,21.label
        val fpDataRDD = sc.textFile(dataDir)
          .map(_.split(fieldDelimiter))
          .map(f => FP(f(0), f(1), f(2), f(3), f(4), f(5), f(6),
            f(7), f(8), f(9), f(10), f(11), f(12), f(13)))

        import sess.implicits._

        // Convert the RDD into a (label, features) DataFrame.
        // Feature order: zfbz flag, hydm, djzclx_dm, kpyf, je, se, jshj.
        val trainingDF = fpDataRDD.map { f =>
          val label = f.label.replaceAll("[)]", "").toDouble
          val features = Vectors.dense(
            if (f.zfbz == "N)") 1.0 else 0.0,
            f.hydm.replaceAll("[(]", "").toDouble,
            f.djzclx_dm.toDouble,
            f.kpyf.toDouble,
            strToDouble(f.je),
            strToDouble(f.se),
            strToDouble(f.jshj))
          (label, features)
        }.toDF("label", "features")

        // Show the training data.
        trainingDF.show()
        println("======================")

        // Create the linear-regression estimator, cap the number of iterations
        // and fit it on the training data to obtain the model.
        val lr = new LinearRegression().setMaxIter(50)
        val model = lr.fit(trainingDF)

        // In-memory sample rows. The features are listed in the same order as the
        // training vector (zfbz, hydm, djzclx_dm, kpyf, je, se, jshj); the original
        // listed them in a different order.
        val testDF = sess.createDataFrame(Seq(
          (0, Vectors.dense(1.0, 3812.0, 171.0, 201612.0, 9401.71, 1598.29, 11000.0)),
          (0, Vectors.dense(1.0, 4190.0, 173.0, 201710.0, 72200.0, 12274.0, 84474.0)),
          (0, Vectors.dense(1.0, 7519.0, 173.0, 201709.0, 99999.99, 3000.0, 102999.99)),

          (1, Vectors.dense(1.0, 1951.0, 173.0, 201612.0, 19743.59, 3356.41, 23100.0)),
          (1, Vectors.dense(1.0, 5219.0, 173.0, 201705.0, 41880.35, 7119.65, 49000.0)),
          (1, Vectors.dense(1.0, 5189.0, 173.0, 201611.0, 1320.93, 224.56, 1545.49)),
          (1, Vectors.dense(0.0, 1779.0, 173.0, 201611.0, 21911.4, 3724.94, 25636.34))
        )).toDF("label", "features")

        testDF.show()

        // Register the sample rows as a temporary view.
        testDF.createOrReplaceTempView("test")
        println("======================")

        // Apply the model and keep the features, label and prediction columns.
        // NOTE(review): as in the original, predictions are computed for the
        // training rows, not for testDF; confirm this is the intended export.
        val tested = model.transform(trainingDF).select("features", "label", "prediction")
        tested.show()

        // Export the predictions to the database, one connection per partition.
        import java.sql.DriverManager
        tested.rdd.foreachPartition { rows =>
          // Skip empty partitions so no connection is opened for nothing.
          if (rows.nonEmpty) {
            // NOTE(review): credentials are hard-coded; move them to configuration.
            val url = "jdbc:mysql://localhost:3306/data?useUnicode=true&characterEncoding=utf8"
            val conn = DriverManager.getConnection(url, "root", "123456")
            try {
              val pstat = conn.prepareStatement("INSERT INTO `test` (`label`, `pre`,`zfbz`,`hydm`, `djzclx_dm`, "
                + "`kpyf`,`je`,`se`,`jshj`) "
                + "VALUES (?,?,?,?,?,?,?,?,?)")
              try {
                for (row <- rows) {
                  // Column 0 is the feature vector, 1 the label, 2 the prediction.
                  val features = row.getAs[org.apache.spark.ml.linalg.Vector](0)
                  pstat.setString(1, row.get(1).toString)
                  pstat.setString(2, row.get(2).toString)
                  // Read the seven features straight from the vector instead of
                  // parsing its string form; they fill placeholders 3 to 9.
                  for (i <- 0 until 7) {
                    pstat.setString(i + 3, features(i).toString)
                  }
                  pstat.addBatch()
                }
                pstat.executeBatch()
              } finally {
                pstat.close()
              }
            } finally {
              // Always release the connection, even when a row fails to bind.
              conn.close()
            }
          }
        }

        sess.stop()
      }
    }

    maven的pom.xml配置文件

    <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
      <!-- Maven build for the Spark ML demo: one dependency (spark-mllib) plus Scala/Eclipse plugin settings. -->
      <modelVersion>4.0.0</modelVersion>
      <groupId>Test</groupId>
      <artifactId>tett1</artifactId>
      <version>0.0.1-SNAPSHOT</version>
      <inceptionYear>2008</inceptionYear>
      <properties>
        <!-- Must match the Scala binary version of the Spark artifact below (spark-mllib_2.11); the archetype default 2.7.0 did not. -->
        <scala.version>2.11.8</scala.version>
      </properties>
    
     <!-- NOTE(review): scala-tools.org appears to be defunct; these entries come from the old archetype and can probably be removed. -->
     <repositories>
        <repository>
          <id>scala-tools.org</id>
          <name>Scala-Tools Maven2 Repository</name>
          <url>http://scala-tools.org/repo-releases</url>
        </repository>
      </repositories>
    
      <pluginRepositories>
        <pluginRepository>
          <id>scala-tools.org</id>
          <name>Scala-Tools Maven2 Repository</name>
          <url>http://scala-tools.org/repo-releases</url>
        </pluginRepository>
      </pluginRepositories>
    
      <dependencies>
       <!-- The explicit scala-library dependency is left disabled as in the original. -->
       <!--  <dependency>
          <groupId>org.scala-lang</groupId>
          <artifactId>scala-library</artifactId>
          <version>${scala.version}</version>
        </dependency> -->
         <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-mllib_2.11</artifactId>
                <version>2.1.0</version>
         </dependency>
      </dependencies>
    
      <build>
        <sourceDirectory>src/main/scala</sourceDirectory>
        <testSourceDirectory>src/test/scala</testSourceDirectory>
        <!-- NOTE(review): the plugins below are declared only under pluginManagement, which supplies
             configuration but does not bind them to the build by itself; confirm how the Scala
             sources are actually compiled (for example by the IDE). -->
        <pluginManagement>
        <plugins>
           <!-- Tests are skipped during the build. -->
           <plugin>
             <groupId>org.apache.maven.plugins</groupId>
              <artifactId>maven-surefire-plugin</artifactId>
              <configuration>
              <skip>true</skip>
             </configuration>
           </plugin> 
        
          <plugin>
            <groupId>org.scala-tools</groupId>
            <artifactId>maven-scala-plugin</artifactId>
            <executions>
              <execution>
                <goals>
                  <goal>compile</goal>
                  <goal>testCompile</goal>
                </goals>
              </execution>
            </executions>
            <configuration>
              <scalaVersion>${scala.version}</scalaVersion>
              <args>
                <!-- Raised from jvm-1.5: Spark 2.1 does not run on Java 5 class targets. -->
                <arg>-target:jvm-1.8</arg>
              </args>
            </configuration>
          </plugin>
          <plugin>
            <groupId>org.apache.maven.plugins</groupId>
            <artifactId>maven-eclipse-plugin</artifactId>
            <configuration>
              <downloadSources>true</downloadSources>
              <buildcommands>
                <buildcommand>ch.epfl.lamp.sdt.core.scalabuilder</buildcommand>
              </buildcommands>
              <additionalProjectnatures>
                <projectnature>ch.epfl.lamp.sdt.core.scalanature</projectnature>
              </additionalProjectnatures>
              <classpathContainers>
                <classpathContainer>org.eclipse.jdt.launching.JRE_CONTAINER</classpathContainer>
                <classpathContainer>ch.epfl.lamp.sdt.launching.SCALA_CONTAINER</classpathContainer>
              </classpathContainers>
            </configuration>
          </plugin>
        </plugins>
        </pluginManagement>
      </build>
      <reporting>
        <plugins>
          <plugin>
            <groupId>org.scala-tools</groupId>
            <artifactId>maven-scala-plugin</artifactId>
            <configuration>
              <scalaVersion>${scala.version}</scalaVersion>
            </configuration>
          </plugin>
        </plugins>
      </reporting>
    </project>
  • 相关阅读:
    Spring_AOP动态代理详解(转)
    Java中spring读取配置文件的几种方法
    SpringMVC工作原理2(代码详解)
    SpringMVC工作原理1(基础机制)
    Web服务器和应用服务器简介
    WEB服务器与应用服务器解疑
    WebService基本概念及原理
    HTTP协议
    TCP、UDP协议间的区别(转)
    HTTP、TCP、UDP以及SOCKET之间的区别/联系
  • 原文地址:https://www.cnblogs.com/Zhanghaonihao/p/9304667.html
Copyright © 2011-2022 走看看