  • Spark accumulators and UDAFs

    Data source: test.json

    {"username": "zhangsan","age": 20}
    {"username": "lisi","age": 18}
    {"username": "wangwu","age": 16}

     Strong typing vs. weak typing

    A DataFrame is weakly typed: the data only has a two-dimensional (tabular) structure, with no mapping to a concrete JVM type.
    A Dataset is strongly typed: the data carries its structure, and there is a mapping between the tabular structure and a concrete (case) class.
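
    A minimal sketch of the practical difference, assuming the test.json above (the input/test.json path, the User case class, and the object name are illustrative):

    import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
    
    case class User(username: String, age: Long)
    
    object TypedVsUntyped {
      def main(args: Array[String]): Unit = {
        val spark: SparkSession = SparkSession.builder().appName("TypedVsUntyped").master("local[*]").getOrCreate()
        import spark.implicits._
    
        // Weakly typed: columns are addressed by name; a wrong column name only fails at runtime
        val df: DataFrame = spark.read.json("input/test.json")
        df.select("age").show()
    
        // Strongly typed: rows map to the User case class; a wrong field name fails at compile time
        val ds: Dataset[User] = df.as[User]
        ds.map(u => u.age + 1).show()
    
        spark.stop()
      }
    }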

    1. Custom accumulator

    Code

    import org.apache.spark.util.AccumulatorV2
    import org.apache.spark.{SparkConf, SparkContext}
    
    object Spark01_TestSer {
      def main(args: Array[String]): Unit = {
        // 1. Create a SparkConf and set the application name
        val conf: SparkConf = new SparkConf().setAppName("SparkCoreTest").setMaster("local[*]")
    
        // 2. Create a SparkContext, the entry point for submitting a Spark application
        val sc: SparkContext = new SparkContext(conf)
    
        // Create and register the custom accumulator before using it in tasks
        val sumAc = new MyAccumulator
        sc.register(sumAc)
        sc.makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangwu", 40))).foreach {
          case (_, age) => sumAc.add(age)
        }
        println(sumAc.value)
    
        // Stop the SparkContext
        sc.stop()
      }
    }
    
    // Custom accumulator: accumulates a running sum and count of Int inputs
    // and exposes their average as a Double
    class MyAccumulator extends AccumulatorV2[Int, Double] {
      var sum = 0
      var count = 0
    
      // Whether the accumulator is still in its initial (zero) state
      override def isZero: Boolean = {
        sum == 0 && count == 0
      }
    
      // Copy this accumulator; Spark ships a copy to each task
      override def copy(): AccumulatorV2[Int, Double] = {
        val myAccumulator = new MyAccumulator
        myAccumulator.sum = this.sum
        myAccumulator.count = this.count
        myAccumulator
      }
    
      // Reset the accumulator to its zero state
      override def reset(): Unit = {
        sum = 0
        count = 0
      }
    
      // Add one value on the executor side
      override def add(v: Int): Unit = {
        sum += v
        count += 1
      }
    
      // Merge partial results from other task-local accumulators on the driver
      override def merge(other: AccumulatorV2[Int, Double]): Unit = {
        other match {
          case o: MyAccumulator =>
            sum += o.sum
            count += o.count
          case _ =>
        }
      }
    
      // Final value: the average (convert to Double to avoid integer division)
      override def value: Double = {
        sum.toDouble / count
      }
    }
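
    With the sample data above, sumAc.value prints 30.0, i.e. (20 + 30 + 40) / 3. For comparison, a minimal sketch of the same average using Spark's built-in LongAccumulator (the variable names are illustrative; sc is the SparkContext from the code above):

    val sumAcc = sc.longAccumulator("sum")
    val cntAcc = sc.longAccumulator("count")
    sc.makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangwu", 40))).foreach {
      case (_, age) =>
        sumAcc.add(age)
        cntAcc.add(1)
    }
    println(sumAcc.value.toDouble / cntAcc.value) // 30.0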
    

    2. Custom UDAF (weakly typed, DataFrame)

    Code

    /*
     * Custom UDAF (weakly typed; mainly used in SQL-style DataFrame queries)
     */
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    
    object SparkSQL05_UDAF {
      def main(args: Array[String]): Unit = {
        val conf: SparkConf = new SparkConf().setAppName("SparkSQL05_UDAF").setMaster("local[*]")
        val sparkSession: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    
        val df: DataFrame = sparkSession.read.json("D:\\IdeaProjects\\spark_test\\input\\test.json")
        df.createOrReplaceTempView("user")
    
        val myAvg1 = new MyAvg1
        sparkSession.udf.register("MyAvg1", myAvg1)
    
        sparkSession.sql("select MyAvg1(age) from user").show()
    
        sparkSession.stop()
    
      }
    }
    
    // Custom UDAF (weakly typed)
    class MyAvg1 extends UserDefinedAggregateFunction {
    
      // Input data type of the aggregate function
      override def inputSchema: StructType = {
        StructType(Array(StructField("age", IntegerType)))
      }
    
      // Type of the aggregation buffer
      override def bufferSchema: StructType = {
        StructType(Array(StructField("sum", LongType), StructField("count", LongType)))
      }
    
      // Return type of the aggregate function
      override def dataType: DataType = DoubleType
    
      // Determinism: whether the same input always yields the same output; returning true is the usual default
      override def deterministic: Boolean = true
    
      // Initialization: put the buffer into its initial state
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        // Reset the running age sum to 0
        buffer(0) = 0L
        // Reset the running row count to 0
        buffer(1) = 0L
      }
    
      // Update the buffer with one input row (skipping null ages)
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (!input.isNullAt(0)) {
          buffer(0) = buffer.getLong(0) + input.getInt(0)
          buffer(1) = buffer.getLong(1) + 1L
        }
      }
    
      // Merge buffers across partitions
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
      }
    
      // Compute the final result
      override def evaluate(buffer: Row): Any = {
        buffer.getLong(0).toDouble / buffer.getLong(1)
      }
    }
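
    With the three rows in test.json, select MyAvg1(age) from user returns 18.0, i.e. (20 + 18 + 16) / 3. Note that since Spark 3.0, UserDefinedAggregateFunction is deprecated in favor of wrapping a strongly typed Aggregator with org.apache.spark.sql.functions.udaf. A minimal sketch, assuming Spark 3.0+ and reusing the AgeBuffer case class from section 3 below (the LongAvg name is illustrative):

    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.{Encoder, Encoders, functions}
    
    // Long-input variant of the average aggregator, usable from SQL
    object LongAvg extends Aggregator[Long, AgeBuffer, Double] {
      override def zero: AgeBuffer = AgeBuffer(0L, 0L)
      override def reduce(b: AgeBuffer, a: Long): AgeBuffer = { b.sum += a; b.count += 1L; b }
      override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
      override def finish(r: AgeBuffer): Double = r.sum.toDouble / r.count
      override def bufferEncoder: Encoder[AgeBuffer] = Encoders.product
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
    
    // Register and query it the same way as MyAvg1 above
    sparkSession.udf.register("MyAvg1", functions.udaf(LongAvg, Encoders.scalaLong))
    sparkSession.sql("select MyAvg1(age) from user").show()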
    

    3. Custom UDAF (strongly typed, Dataset)

    Code

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, SparkSession, TypedColumn}
    
    object SparkSQL06_UDAF {
      def main(args: Array[String]): Unit = {
        val conf: SparkConf = new SparkConf().setAppName("SparkSQL06_UDAF").setMaster("local[*]")
        val sparkSession: SparkSession = SparkSession.builder().config(conf).getOrCreate()
        import sparkSession.implicits._
    
        val df: DataFrame = sparkSession.read.json("D:\\IdeaProjects\\spark_test\\input\\test.json")
        df.createOrReplaceTempView("user")
    
        val myAvg2 = new MyAvg2
    
        val column: TypedColumn[User01, Double] = myAvg2.toColumn
    
        val ds: Dataset[User01] = df.as[User01]
        ds.select(column).show()
    
        sparkSession.stop()
      }
    }
    
    // Case class for the input rows
    case class User01(username: String, age: Long)
    // Buffer type; the buffer is mutated during aggregation, so its fields must be declared as var
    case class AgeBuffer(var sum: Long, var count: Long)
    
    // Custom UDAF (strongly typed)
    // @tparam IN  input data type
    // @tparam BUF buffer data type
    // @tparam OUT output data type
    class MyAvg2 extends Aggregator[User01, AgeBuffer, Double] {
      // Initialize the aggregation buffer
      override def zero: AgeBuffer = {
        AgeBuffer(0L,0L)
      }
    
      // Aggregate the input rows within the current partition
      override def reduce(b: AgeBuffer, a: User01): AgeBuffer = {
        b.sum += a.age
        b.count += 1L
        b
      }
    
      // Merge buffers across partitions
      override def merge(b1: AgeBuffer, b2: AgeBuffer): AgeBuffer = {
        b1.sum += b2.sum
        b1.count += b2.count
        b1
      }
    
      // Return the final result
      override def finish(reduction: AgeBuffer): Double = {
        reduction.sum.toDouble/reduction.count.toDouble
      }
    
      // Encoders for serializing/deserializing the Dataset; essentially boilerplate.
      // Use Encoders.product for custom (case class) types; for built-in value types pick the matching Encoders.scalaXxx
      override def bufferEncoder: Encoder[AgeBuffer] = {
        Encoders.product
      }
    
      override def outputEncoder: Encoder[Double] = {
        Encoders.scalaDouble
      }
    }
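
    With the three rows in test.json, this should print a single value of 18.0, i.e. (20 + 18 + 16) / 3. The typed column can also be given a readable alias before showing it (the avg_age name is illustrative):

    ds.select(myAvg2.toColumn.name("avg_age")).show()
    // +-------+
    // |avg_age|
    // +-------+
    // |   18.0|
    // +-------+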
    

      
