zoukankan      html  css  js  c++  java
  • 详解Spark sql用户自定义函数:UDF与UDAF

    UDAF = USER DEFINED AGGREGATION FUNCTION

           Spark sql提供了丰富的内置函数供猿友们使用,辣为何还要用户自定义函数呢?实际的业务场景可能很复杂,内置函数hold不住,所以Spark sql提供了可扩展的内置函数接口:哥们,你的业务太变态了,我满足不了你,自己按照我的规范去定义一个sql函数,该怎么折腾就怎么折腾! 
    例如,MySQL数据库中有一张task表,共两个字段taskid (任务ID)与taskParam(JSON格式的任务请求参数)。简单起见,这里只列出一条记录:

    taskid 1              taskParam {"endAge":["50"],"endDate":["2016-06-21"],"startAge":["10"],"startDate":["2016-06-21"]}

    假设应用程序已经读取了MySQL中这张表的记录,并通过 DateFrame注册成了一张临时表 task。问题来了:怎么获取taskParam中startAge的第一个值呢?

    sqlContext.sql("select taskid,getJsonFieldUDF(taskParm,'startAge')")

    这个时候,我们就需要自定义一个UDF函数了,取名getJsonFieldUDF。Java版本的代码大致如下:

    package cool.pengych.sparker.product;
    import org.apache.spark.sql.api.java.UDF2;
    import com.alibaba.fastjson.JSONObject;
    /**
     * 用户自定义函数
     * @author pengyucheng
     */
    public class GetJsonObjectUDF implements UDF2<String,String,String>
    {
        /**
         * 获取数组类型json字符串中某一字段的值
         */
        @Override
        public String call(String json, String field) throws Exception 
        {
            try
            {
                JSONObject jsonObject = JSONObject.parseObject(json);
                return jsonObject.getJSONArray(field).getString(0);
            }
            catch(Exception e)
            {
                e.printStackTrace();
            }
            return null;
        }
    }

    这样的需求在实际项目中是很普遍的:请求参数经常以json格式存储在数据库中。这里还是先以Scala实现一个简单的hello world级别的小样为例,来体验udf与udaf的使用好了。

    问题

    将如下数组:

    val bigData = Array("Spark","Hadoop","Flink","Spark","Hadoop","Flink",
    "Spark","Hadoop","Flink","Spark","Hadoop","Flink")

    中的字符分组聚合并计算出每个字符的长度及字符出现的个数。正常结果 
    如下:

    +------+-----+------+
    |  name|count|length|
    +------+-----+------+
    | Spark|    4|     5|
    | Flink|    4|     5|
    |Hadoop|    4|     6|
    +------+-----+------+

    注:‘spark’ 这个字符的长度为5 ,共出现了4次。

    分析

      • 自定义个一个求字符串长度的函数 
        自定义的sql函数,与scala中的普通函数一样,只不过在使用上前者需要先在sqlContext中进行注册。
      • 自定义一个聚合函数 
        按照字符串名称分组后,调用自定义的聚合函数实现累加。 
        啊,好抽象,直接看代码吧!

    代码

    package main.scala
    
    import org.apache.spark.SparkContext
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql.types.StructField
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
    import org.apache.spark.sql.types.IntegerType
    import org.apache.spark.sql.types.DataType
    import org.apache.spark.sql.expressions.MutableAggregationBuffer
    
    /**
     * Spark SQL UDAS:user defined aggregation function 
     * UDF: 函数的输入是一条具体的数据记录,实现上讲就是普通的scala函数-只不过需要注册
     * UDAF:用户自定义的聚合函数,函数本身作用于数据集合,能够在具体操作的基础上进行自定义操作
     */
    object SparkSQLUDF {
    
      def main(args: Array[String]): Unit = {
    
        val conf = new SparkConf().setMaster("local[*]").setAppName("SparkSQLWindowFunctionOps")
        val sc = new SparkContext(conf)
    
        val hiveContext = new SQLContext(sc)
    
        val bigData = Array("Spark","Hadoop","Flink","Spark","Hadoop","Flink","Spark","Hadoop","Flink","Spark","Hadoop","Flink")
        val bigDataRDD = sc.parallelize(bigData)
    
         val bigDataRowRDD = bigDataRDD.map(line => Row(line))
         val structType = StructType(Array(StructField("name",StringType,true)))
         val bigDataDF = hiveContext.createDataFrame(bigDataRowRDD, structType)
    
         bigDataDF.registerTempTable("bigDataTable")
    
        /*
         * 通过HiveContext注册UDF,在scala2.10.x版本UDF函数最多可以接受22个输入参数
         */
         hiveContext.udf.register("computeLength",(input:String) => input.length)
         hiveContext.sql("select name,computeLength(name)  as length from bigDataTable").show
    
         //while(true){}
    
         hiveContext.udf.register("wordCount",new MyUDAF)
         hiveContext.sql("select name,wordCount(name) as count,computeLength(name) as length from bigDataTable group by name ").show
      }
    }
    
    /**
     * 用户自定义函数
     */
     class MyUDAF extends UserDefinedAggregateFunction
     {
      /**
       * 指定具体的输入数据的类型
       * 自段名称随意:Users can choose names to identify the input arguments - 这里可以是“name”,或者其他任意串
       */
      override def inputSchema:StructType = StructType(Array(StructField("name",StringType,true)))
    
      /**
       * 在进行聚合操作的时候所要处理的数据的中间结果类型
       */
      override def bufferSchema:StructType = StructType(Array(StructField("count",IntegerType,true)))
    
      /**
       * 返回类型
       */
      override def dataType:DataType = IntegerType
    
      /**
       * whether given the same input,
       * always return the same output
       * true: yes 
       */
      override def deterministic:Boolean = true
    
      /**
       * Initializes the given aggregation buffer
       */
      override def initialize(buffer:MutableAggregationBuffer):Unit = {buffer(0)=0}
    
      /**
       * 在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
       * 本地的聚合操作,相当于Hadoop MapReduce模型中的Combiner
       */
      override def update(buffer:MutableAggregationBuffer,input:Row):Unit={
        buffer(0) = buffer.getInt(0)+1
      }
    
      /**
       * 最后在分布式节点进行local reduce完成后需要进行全局级别的merge操作
       */
      override def merge(buffer1:MutableAggregationBuffer,buffer2:Row):Unit={
        buffer1(0) = buffer1.getInt(0)+buffer2.getInt(0)
      }
    
      /**
       * 返回UDAF最后的计算结果
       */
      override def evaluate(buffer:Row):Any = buffer.getInt(0)
    }

    执行结果:

    16/06/29 19:30:24 INFO DAGScheduler: ResultStage 5 (show at SparkSQLUDF.scala:48) finished in 1.625 s
    +------+-----+------+
    |  name|count|length|
    +------+-----+------+
    | Spark|    4|     5|
    | Flink|    4|     5|
    |Hadoop|    4|     6|
    +------+-----+------+
    
    16/06/29 19:30:24 INFO DAGScheduler: Job 3 finished: show at SparkSQLUDF.scala:48, took 1.717878 s

    总结

      • 呼叫spark大神升级udaf实现 
        为了自己实现一个sql聚合函数,我需要继承UserDefinedAggregateFunction并实现8个抽象方法!8个方法啊!what’s a disaster ! 然而,要想在sql中完成符合特定业务场景的聚合类(a = aggregation)功能,就得udaf。 
        怎么理解MutableAggregationBuffer呢?就是存储中间结果的,聚合就意味着多条记录的累加等操作。

      • udf与udaf注册语法

     hiveContext.udf.register("computeLength",(input:String) => input.length)
    
     hiveContext.udf.register("wordCount",new MyUDAF)
  • 相关阅读:
    Netty简单聊天室
    JDK环境变量配置
    EasyUI Tabs
    NIO(五)
    NIO(四)
    银行对公业务和对私业务
    mysql常用操作
    LInux安装MySQL5.7.24详情
    Python3 SMTP发送邮件
    linux下sendmail邮件系统安装详情
  • 原文地址:https://www.cnblogs.com/itboys/p/7281270.html
Copyright © 2011-2022 走看看