zoukankan      html  css  js  c++  java
  • 048 SparkSQL自定义UDAF函数

    一:程序

    1.需求

      实现一个求平均值的UDAF。

      这里保留Double格式化,在完成求平均值后与系统的AVG进行对比,观察正确性。

    2.SparkSQLUDFDemo程序

     1 package com.scala.it
     2 
     3 import org.apache.spark.sql.hive.HiveContext
     4 import org.apache.spark.{SparkConf, SparkContext}
     5 
     6 import scala.math.BigDecimal.RoundingMode
     7 
     8 object SparkSQLUDFDemo {
     9   def main(args: Array[String]): Unit = {
    10     val conf = new SparkConf()
    11       .setMaster("local[*]")
    12       .setAppName("udf")
    13     val sc = SparkContext.getOrCreate(conf)
    14     val sqlContext = new HiveContext(sc)
    15 
    16     // ==================================
    17     // 写一个Double数据格式化的自定义函数(给定保留多少位小数部分)
    18     sqlContext.udf.register(
    19       "doubleValueFormat", // 自定义函数名称
    20       (value: Double, scale: Int) => {
    21         // 自定义函数处理的代码块
    22         BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_DOWN).doubleValue()
    23       })
    24 
    25     // 自定义UDAF
    26     sqlContext.udf.register("selfAvg", AvgUDAF)
    27 
    28     sqlContext.sql(
    29       """
    30         |SELECT
    31         |  deptno,
    32         |  doubleValueFormat(AVG(sal), 2) AS avg_sal,
    33         |  doubleValueFormat(selfAvg(sal), 2) AS self_avg_sal
    34         |FROM hadoop09.emp
    35         |GROUP BY deptno
    36       """.stripMargin).show()
    37 
    38   }
    39 }

    3.AvgUDAF程序

     1 package com.scala.it
     2 
     3 import org.apache.spark.sql.Row
     4 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
     5 import org.apache.spark.sql.types._
     6 
     7 
     8 object AvgUDAF extends UserDefinedAggregateFunction{
     9   override def inputSchema: StructType = {
    10     // 给定UDAF的输出参数类型
    11     StructType(
    12       StructField("sal", DoubleType) :: Nil
    13     )
    14   }
    15 
    16   override def bufferSchema: StructType = {
    17     // 在计算过程中会涉及到的缓存数据类型
    18     StructType(
    19       StructField("total_sal", DoubleType) ::
    20         StructField("count_sal", LongType) :: Nil
    21     )
    22   }
    23 
    24   override def dataType: DataType = {
    25     // 给定该UDAF返回的数据类型
    26     DoubleType
    27   }
    28 
    29   override def deterministic: Boolean = {
    30     // 主要用于是否支持近似查找,如果为false:表示支持多次查询允许结果不一样,为true表示结果必须一样
    31     true
    32   }
    33 
    34   override def initialize(buffer: MutableAggregationBuffer): Unit = {
    35     // 初始化 ===> 初始化缓存数据
    36     buffer.update(0, 0.0) // 初始化total_sal
    37     buffer.update(1, 0L) // 初始化count_sal
    38   }
    39 
    40   override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    41     // 根据输入的数据input,更新缓存buffer的内容
    42     // 获取输入的sal数据
    43     val inputSal = input.getDouble(0)
    44 
    45     // 获取缓存中的数据
    46     val totalSal = buffer.getDouble(0)
    47     val countSal = buffer.getLong(1)
    48 
    49     // 更新缓存数据
    50     buffer.update(0, totalSal + inputSal)
    51     buffer.update(1, countSal + 1L)
    52   }
    53 
    54   override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    55     // 当两个分区的数据需要进行合并的时候,该方法会被调用
    56     // 功能:将buffer2中的数据合并到buffer1中
    57     // 获取缓存区数据
    58     val buf1Total = buffer1.getDouble(0)
    59     val buf1Count = buffer1.getLong(1)
    60 
    61     val buf2Total = buffer2.getDouble(0)
    62     val buf2Count = buffer2.getLong(1)
    63 
    64     // 更新缓存区
    65     buffer1.update(0, buf1Total + buf2Total)
    66     buffer1.update(1, buf1Count + buf2Count)
    67   }
    68 
    69   override def evaluate(buffer: Row): Any = {
    70     // 求返回值
    71     buffer.getDouble(0) / buffer.getLong(1)
    72   }
    73 }

    4.效果

      

    二:知识点

    1.udf注册

      

    2.解释上面的update

      重要的是两个参数的意思,不然程序有些看不懂。

      所以,程序的意思是,第一位存储总数,第二位存储个数。

      

    3.还要解释一个StructType的生成

      在以前的程序中,是使用Array来生成的。如:

        

      在上面的程序中,不是这种方式,使用集合的方式。

        

  • 相关阅读:
    html5 java多图片上传
    ajax post form表单
    java获取图片文件返回地址
    教你使用servlet拦截器,放行不需要拦截的内容
    实用的request接收值的工具类
    spring3的定时执行任务
    centos7.4无法启动之找不到EFIBOOTgrubx64.efi
    redhat7.2上搭建网易、epel的yum repo
    python2和python3中的关键字的区别--keyword模块
    搭建lamp的脚本
  • 原文地址:https://www.cnblogs.com/juncaoit/p/9386215.html
Copyright © 2011-2022 走看看