zoukankan      html  css  js  c++  java
  • spark自定义函数之——UDAF使用详解及代码示例

    UDAF简介

    UDAF(User Defined Aggregate Function)即用户定义的聚合函数,聚合函数和普通函数的区别是什么呢,普通函数是接受一行输入产生一个输出,聚合函数是接受一组(一般是多行)输入然后产生一个输出,即将一组的值想办法聚合一下。

    UDAF的误区

    我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:

    select max(foo) from foobar group by bar;

    表示根据bar字段分组,然后求每个分组的最大值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:

    select max(foo) from foobar group by bar;

    这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最大值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量。

    UDAF使用

    UDAF 的使用方法有这两种

    • 继承UserDefinedAggregateFunction
    • 继承Aggregator

    下面介绍两种UDAF的实现

    方法一:继承UserDefinedAggregateFunction

    使用UserDefinedAggregateFunction的套路:

    1. 自定义类继承UserDefinedAggregateFunction,对每个阶段方法做实现

    2. 在spark中注册UDAF,为其绑定一个名字

    3. 然后就可以在sql语句中使用上面绑定的名字调用

    下面写一个计算平均值的UDAF例子,首先定义一个类继承UserDefinedAggregateFunction:

    package cc11001100.spark.sql.udaf
     
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
     
    object AverageUserDefinedAggregateFunction extends UserDefinedAggregateFunction {
     
      // 聚合函数的输入数据结构
      override def inputSchema: StructType = StructType(StructField("input", LongType) :: Nil)
     
      // 缓存区数据结构
      override def bufferSchema: StructType = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
     
      // 聚合函数返回值数据结构
      override def dataType: DataType = DoubleType
     
      // 聚合函数是否是幂等的,即相同输入是否总是能得到相同输出
      override def deterministic: Boolean = true
     
      // 初始化缓冲区
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L
        buffer(1) = 0L
      }
     
      // 给聚合函数传入一条新数据进行处理
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (input.isNullAt(0)) return
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1
      }
     
      // 合并聚合函数缓冲区
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
      }
     
      // 计算最终结果
      override def evaluate(buffer: Row): Any = buffer.getLong(0).toDouble / buffer.getLong(1)
     
    }

    然后注册并使用它:

    package cc11001100.spark.sql.udaf
     
    import org.apache.spark.sql.SparkSession
     
    object SparkSqlUDAFDemo_001 {
     
      def main(args: Array[String]): Unit = {
     
        val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
        spark.read.json("data/user").createOrReplaceTempView("v_user")
        spark.udf.register("u_avg", AverageUserDefinedAggregateFunction)
        // 将整张表看做是一个分组对求所有人的平均年龄
        spark.sql("select count(1) as count, u_avg(age) as avg_age from v_user").show()
        // 按照性别分组求平均年龄
        spark.sql("select sex, count(1) as count, u_avg(age) as avg_age from v_user group by sex").show()
     
      }
     
    }

    结果

    //使用到的数据集
    {"id": 1001, "name": "foo", "sex": "man", "age": 20}
    {"id": 1002, "name": "bar", "sex": "man", "age": 24}
    {"id": 1003, "name": "baz", "sex": "man", "age": 18}
    {"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
    {"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
    {"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
    
    //运行结果
    +-----+--------+
    | count|avg_age|
    +-----+--------+
    | 6 | 19.6666|
    +-----+--------+
    +-----+--------+---------+
    | sex | count | avg_age |
    +-----+--------+---------+
    | man| 19.6666|20.666666|

    |woman| 19.6666|20.666666|

    +-----+--------+---------+

    方法二:继承Aggregator

    还有另一种方式就是继承Aggregator这个类,优点是可以带类型:

    package cc11001100.spark.sql.udaf
     
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.{Encoder, Encoders}
     
    /**
      * 计算平均值
      *
      */
    object AverageAggregator extends Aggregator[User, Average, Double] {
     
      // 初始化buffer
      override def zero: Average = Average(0L, 0L)
     
      // 处理一条新的记录
      override def reduce(b: Average, a: User): Average = {
        b.sum += a.age
        b.count += 1L
        b
      }
     
      // 合并聚合buffer
      override def merge(b1: Average, b2: Average): Average = {
        b1.sum += b2.sum
        b1.count += b2.count
        b1
      }
     
      // 减少中间数据传输
      override def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
     
      override def bufferEncoder: Encoder[Average] = Encoders.product
     
      // 最终输出结果的类型
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
     
    }
     
    /**
      * 计算平均值过程中使用的Buffer
      *
      * @param sum
      * @param count
      */
    case class Average(var sum: Long, var count: Long) {
    }
     
    case class User(id: Long, name: String, sex: String, age: Long) {
    }

    调用:

    package cc11001100.spark.sql.udaf
     
    import org.apache.spark.sql.SparkSession
     
    object AverageAggregatorDemo_001 {
     
      def main(args: Array[String]): Unit = {
     
        val spark = SparkSession.builder().master("local[*]").appName("SparkStudy").getOrCreate()
        import spark.implicits._
        val user = spark.read.json("data/user").as[User]
        user.select(AverageAggregator.toColumn.name("avg")).show()
     
      }
     
    }
    //运行结果
    +--------+
    | avg |
    +--------+
    | 19.6666|
    +--------+

     

  • 相关阅读:
    js作用域
    正则表达式方法
    for-in
    关于this
    由indexOf引发的一系列探究和思考
    Html第一个程序
    转:1.ASP.Net MVC开发基础学习笔记;2.HtmlHelper与扩展方法
    转:单例模式
    转:ORM框架EF(Entity framework)实体框架
    C#lambda表达式
  • 原文地址:https://www.cnblogs.com/yyy-blog/p/10280795.html
Copyright © 2011-2022 走看看