zoukankan      html  css  js  c++  java
  • Spark(十三)【SparkSQL自定义UDF/UDAF函数】

    一.UDF(一进一出)

    步骤

    ① 注册UDF函数,可以使用匿名函数。

    ② 在sql查询的时候使用自定义的UDF。

    示例

    import org.apache.spark.sql.{DataFrame, SparkSession}
    
    /**
     * @description: UDF一进一出
     * @author: HaoWu
     * @create: 2020年08月09日
     */
    object UDF_Test {
      def main(args: Array[String]): Unit = {
        //创建SparkSession
        val session: SparkSession = SparkSession.builder
          .master("local[*]")
          .appName("MyApp")
          .getOrCreate()
        //注册UDF
        session.udf.register("addHello",(name:String) => "hello:"+name)
        //读取json格式文件{"name":"zhangsan","age":20},创建DataFrame
        val df: DataFrame = session.read.json("input/1.txt")
        //创建临时视图:person
        df.createOrReplaceTempView("person")
        //查询的时候使用UDF
        session.sql(
          """select
            |addHello(name),
            |age
            |from person
            |""".stripMargin).show
      }
    }
    

    结果

    |addHello(name)|age|
    +--------------+---+
    |hello:zhangsan| 20|
    |    hello:lisi| 30|
    +--------------+---+
    

    二.UDAF(多近一出)

    spark2.X 实现方式

    2.X版本:UserDefinedAggregateFunction 无类型或弱类型

    步骤

    ①继承UserDefinedAggregateFunction,实现其中的方法

    ②创建函数对象,注册函数,在sql中使用

        //创建UDFA对象
        val avgDemo1: Avg_UDAF_Demo1 = new Avg_UDAF_Demo1
        //在spark中注册聚合函数
        spark.udf.register("ageDemo1", avgDemo1)
    
    案例

    需求:实现avg()聚合函数的功能,要求结果是Double类型

    代码实现

    ①继承UserDefinedAggregateFunction,实现其中的方法
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StructField, StructType}
    
    /**
     * @description: UDAF(多近一出):求age的平均值
     *              2.X 版本继承UserDefinedAggregateFunction类,弱类型
     *               非常类似累加器,aggregateByKey算子的操作,有个ZeroValue,不断将输入的值做归约操作,然后再赋值给ZeroValue
     * @author: HaoWu
     * @create: 2020年08月08日
     */
    class Avg_UDAF_Demo1 extends UserDefinedAggregateFunction {
      //聚合函数输入参数的数据类型,
      override def inputSchema = StructType(StructField("age", LongType) :: Nil)
    
      //聚合函数缓冲区中值的数据类型(sum,count)
      override def bufferSchema = StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
    
      //函数返回值的数据类型
      override def dataType = DoubleType
    
      //稳定性:对于相同的输入是否一直返回相同的输出,一般都是true
      override def deterministic = true
    
      //函数缓冲区初始化,就是ZeroValue清空
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        //缓存区看做一个数组,将每个元素置空
        //sum
        buffer(0) = 0L
        //count
        buffer(1) = 0L
    
      }
      //更新缓冲区中的数据->将输入的值和缓存区数据合并
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        //input是Row类型,通过getXXX(索引值)取数据
        if (!input.isNullAt(0)) {
          val age = input.getLong(0)
          buffer(0) = buffer.getLong(0) + age
          buffer(1) = buffer.getLong(1) + 1
        }
      }
      //合并缓冲区 (sum1,count1) + (sum2,count2) 合并
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
      }
      //计算最终结果
      override def evaluate(buffer: Row) = buffer.getLong(0).toDouble/buffer.getLong(1)
    }
    
    ②创建函数对象,注册函数,在sql中使用
    /**
     * @description: 实现集合函数avg的功能
     * @author: HaoWu
     * @create: 2020年08月13日
     */
    object UDAF_Test {
      def main(args: Array[String]): Unit = {
        
        //创建SparkSession
        val spark: SparkSession = SparkSession.builder
          .master("local[*]")
          .appName("MyApp")
          .getOrCreate()
        //读取json格式文件{"name":"zhangsan","age":20}
        val df: DataFrame = spark.read.json("input/1.txt")
        //创建临时视图:person
        df.createOrReplaceTempView("person")
        //创建UDFA对象
        val avgDemo1: Avg_UDAF_Demo1 = new Avg_UDAF_Demo1
        //在spark中注册聚合函数
        spark.udf.register("ageDemo1", avgDemo1)
        //查询的时候使用UDF
        spark.sql(
          """select
            |ageDemo1(age)
            |from person
            |""".stripMargin).show
      }
    }
    

    spark3.X实现方式

    3.x版本: 认为2.X继承UserDefinedAggregateFunction的方式过时,推荐继承Aggregator ,是强类型

    步骤

    ①继承Aggregator [-IN, BUF, OUT],声明泛型,实现其中的方法

        abstract class Aggregator[-IN, BUF, OUT]  
            IN: 输入的类型      
            BUF:  缓冲区类型     
            OUT: 输出的类型      
    

    ②创建函数对象,注册函数,在sql中使用

        //创建UDFA对象
        val avgDemo2: Avg_UDAF_Demo2 = new Avg_UDAF_Demo2
        //在spark中注册聚合函数
        spark.udf.register("myAvg",functions.udaf(avgDemo2))
    

    注意:2.X和3.X的注册方式不同

    案例

    需求:实现avg()聚合函数的功能,要求结果是Double类型

    代码实现

    ①继承Aggregator [-IN, BUF, OUT],声明泛型,实现其中的方法

    其中缓冲区数据用样例类进行封装。

    MyBuffer类

    /**
     * 定义MyBuffer样例类
     * @param sum  组数据sum和
     * @param count  组的数据个数
     */
    case class MyBuffer(var sum: Long, var count: Long)
    

    自定义UDAF函数

    import org.apache.spark.sql.Encoders
    import org.apache.spark.sql.expressions.Aggregator
    
    /**
     * @description: UDAF(多近一出):求age的平均值
     *              3.X Aggregator,强类型
     *               非常类似累加器,aggregateByKey算子的操作,有个ZeroValue,不断将输入的值做归约操作,然后再赋值给ZeroValue
     * @author: HaoWu
     * @create: 2020年08月08日
     */
    class Avg_UDAF_Demo2 extends Aggregator[Long, MyBuffer, Double] {
      //函数缓冲区初始化,就是ZeroValue清空
      override def zero = MyBuffer(0L, 0L)
    
      //将输入的值和缓存区数据合并
      override def reduce(b: MyBuffer, a: Long) = {
        b.sum = b.sum + a
        b.count = b.count + 1
        b
      }
    
      //合并缓冲区
      override def merge(b1: MyBuffer, b2: MyBuffer) = {
        b1.sum = b1.sum + b2.sum
        b1.count = b1.count + b2.count
        b1
      }
    
      //计算最终结果
      override def finish(reduction: MyBuffer) = reduction.sum.toDouble / reduction.count
    
      /* scala中
         常见的数据类型: Encoders.scalaXXX
         自定义的类型:ExpressionEncoder[T]() 返回 Encoder[T]
         样例类(都是Product类型): Encoders.product[T],返回Produce类型的Encoder!
                                                */
      //缓存区的Encoder类型
      override def bufferEncoder = Encoders.product[MyBuffer]
    
      //输出结果的Encoder类型
      override def outputEncoder = Encoders.scalaDouble
    }
    
    ②创建函数对象,注册函数,在sql中使用
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.{DataFrame, Row, SparkSession, functions}
    
    /**
     * @description: 实现集合函数avg的功能
     * @author: HaoWu
     * @create: 2020年08月13日
     */
    object UDAF_Test {
      def main(args: Array[String]): Unit = {
    
        //创建SparkSession
        val spark: SparkSession = SparkSession.builder
          .master("local[*]")
          .appName("MyApp")
          .getOrCreate()
        //读取json格式文件{"name":"zhangsan","age":20}
        val df: DataFrame = spark.read.json("input/1.txt")
        //创建临时视图:person
        df.createOrReplaceTempView("person")
        //创建UDFA对象
        val avgDemo2: Avg_UDAF_Demo2 = new Avg_UDAF_Demo2
        //在spark中注册聚合函数
        spark.udf.register("myAvg",functions.udaf(avgDemo2))
        //查询的时候使用UDF
        spark.sql(
          """select
            |myAvg(age)
            |from person
            |""".stripMargin).show
      }
    }
    
  • 相关阅读:
    RTMP协议在线教育课堂web视频直播点播平台EasyDSS鉴权模块优化说明
    RTMP协议在线教育课堂web视频直播点播平台EasyDSS在大量设备开启录像后为什么会导致系统卡死?
    RTMP协议视频直播点播智能分析平台EasyDSS优化视频水印生成效率参考
    互联网在线课堂直播点播视频平台EasyDSS访问页面报NO DSS SERVICE如何排查?
    RTMP直播点播平台EasyDSS下载录像文件为什么会提示:最大播放下载录像间隔是3小时?
    RTMP协议互联网教育课堂直播点播系统EasyDSS获取直播信息优化设计方案介绍
    如何将RTMP协议视频直播点播平台EasyDSS录像文件存储在其他的空闲磁盘内?
    POJ 3069 Saruman's Army 贪心
    POJ3617 Best Cow line 简单题
    POJ 1852 Ants 思维题 简单题
  • 原文地址:https://www.cnblogs.com/wh984763176/p/13497053.html
Copyright © 2011-2022 走看看