zoukankan      html  css  js  c++  java
  • Scala + Spark 2.0（RDD / DataFrame）：分布式计算欧式距离

    1、配置文件

    package config
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.{SparkConf, SparkContext}
    /**
     * Shared Spark configuration for the jobs in this project.
     *
     * The SparkSession is created first and `sc` is taken from it, so the application
     * always has exactly one SparkContext. `getOrCreate` reuses an already-active
     * session/context instead of failing the way a second `new SparkContext` would.
     */
    case object conf {
      private val master = "local[*]"
      val confs: SparkConf = new SparkConf().setMaster(master).setAppName("jobs")
      // To run against a standalone cluster, use the cluster's spark:// master URL.
      // (Port 4040 is the driver web UI, not a master endpoint.)
      //   val confs: SparkConf = new SparkConf().setMaster("spark://laptop-2up1s8pr:7077").setAppName("jobs")

      /** The application's SparkSession, built from `confs`. */
      val spark_session: SparkSession = SparkSession.builder()
        .appName("jobs").config(confs).getOrCreate()

      /** The SparkContext owned by `spark_session`. */
      val sc: SparkContext = spark_session.sparkContext
      sc.setLogLevel("ERROR")

      // Allow cartesian products (cross joins); Spark 2.x rejects them by default.
      spark_session.conf.set("spark.sql.crossJoin.enabled", true)
    }
    

      

    2、欧式距离计算

    package classifierAlg
    import config.conf.{sc, spark_session}
    import config.conf.spark_session.implicits._
    import org.apache.spark.sql.functions._
    import breeze.linalg._
    import breeze.numerics._
    import org.apache.log4j.{Level, Logger}
    import org.apache.parquet.schema.Types.ListBuilder
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.types.{StructField, _}
    import org.apache.spark.sql.{Column, DataFrame, Row}
    
    import scala.collection.mutable.ListBuffer
    object euclideanDist {

      /**
       * Builds a one-row DataFrame with five nullable double columns `dt1`..`dt5`
       * holding the sample feature vector (5.1, 3.8, 1.6, 0.2, 1.0).
       */
      def testFeature(): DataFrame = {
        val someData: Seq[Row] = Seq(Row(5.1, 3.8, 1.6, 0.2, 1.0))
        val someSchema: StructType = StructType(
          (1 to 5).map(i => StructField(s"dt$i", DoubleType, nullable = true))
        )
        spark_session.createDataFrame(sc.parallelize(someData), someSchema)
      }

      /**
       * Computes the Euclidean distance from `query` to every row of a CSV file and
       * shows the distances strictly below `threshold`.
       *
       * The CSV is read without a header, its five columns are named ft1..ft4 and
       * label, and every column is cast to double. Component i of `query` is compared
       * with column i. Rows with a null in any compared column are skipped. Such nulls
       * come from values that fail the double cast, such as a header line if the file
       * has one, or from empty fields.
       *
       * Matrix and vector operations reference:
       * https://blog.csdn.net/hellozhxy/article/details/82852561
       *
       * The default arguments reproduce the original hard-coded run.
       *
       * @param path      CSV file to read; it must have exactly five columns
       * @param query     point to measure from; must not have more components than the data has columns
       * @param threshold only distances strictly below this value are shown
       * @throws IllegalArgumentException if `query` has more components than the data has columns
       */
      def euclideanDisting(
          path: String = "data/irsdf/part-00000-ca2d6ce7-bcd0-4c24-aba9-e8cb01dcc04c-c000.csv",
          query: Array[Double] = Array(5.1, 3.5, 1.4, 0.2, 1.0),
          threshold: Double = 0.2
      ): Unit = {
        val df: DataFrame = spark_session.read.csv(path).toDF("ft1", "ft2", "ft3", "ft4", "label")
        val colsd: Array[Column] = df.columns.map(name => df(name).cast(DoubleType))
        val df2: DataFrame = df.select(colsd: _*)
        println(df2.count())

        val dim: Int = query.length
        // Fail fast on the driver instead of with an index error inside an executor task.
        require(dim <= df2.columns.length,
          s"query has $dim components but the data has only ${df2.columns.length} columns")

        // NOTE(review): with the default 5-component query the `label` column is part of
        // the distance; pass a 4-component query to compare only the features ft1..ft4.
        // Rows holding a null in a compared column are dropped rather than crashing the
        // job (the old row(i).toString threw a NullPointerException on them).
        val dists: RDD[Double] = df2.rdd
          .filter(row => (0 until dim).forall(i => !row.isNullAt(i)))
          .map { row =>
            val sumSq = (0 until dim).foldLeft(0.0) { (acc, i) =>
              val d = row.getDouble(i) - query(i)
              acc + d * d
            }
            math.sqrt(sumSq)
          }

        val distDf: DataFrame = dists.toDF("dist")
        distDf.where(distDf("dist") < threshold).show()
      }

      /** Entry point: runs `euclideanDisting` with its default arguments. */
      def main(args: Array[String]): Unit = {
        euclideanDisting()

        // Manual sanity check with breeze: the distance between (5.1, 3.5, 1.4, 0.2, 1.0)
        // and (4.9, 3.0, 1.4, 0.2, 1.0) is 0.5385164807134502.
        //    val d1: DenseVector[Double] = DenseVector(5.1, 3.5, 1.4, 0.2, 1.0)
        //    val d2: DenseVector[Double] = DenseVector(4.9, 3.0, 1.4, 0.2, 1.0)
        //    println(math.sqrt((((d1-d2)*(d1-d2)).sum)))
      }
    }
    

      

  • 相关阅读:
    .Net魔法堂:log4net详解
    CentOS6.5菜鸟之旅:安装SUN JDK1.7和Tomcat7
    Java魔法堂:注释和注释模板
    Eclipse魔法堂:任务管理器
    CentOS6.5菜鸟之旅:VirtualBox4.3识别USB设备
    Windows魔法堂:解决“由于启动计算机时出现页面文件配置问题.......”
    JS魔法堂:IE5~9的Drag&Drop API
    CentOS6.5菜鸟之旅:安装VirtualBox4.3
    HTML5魔法堂:全面理解Drag & Drop API
    byzx
  • 原文地址:https://www.cnblogs.com/wuzaipei/p/12627335.html
Copyright © 2011-2022 走看看