zoukankan      html  css  js  c++  java
  • Spark SQL 自定义聚合函数(UDAF)

    自定义聚合函数(UDAF)有两种实现方式,区别在于是否使用强类型。参考 demo:https://github.com/asker124143222/spark-demo

    1、不使用强类型:继承 UserDefinedAggregateFunction(注:该类自 Spark 3.0 起已标记为 deprecated,新代码推荐使用 Aggregator 并通过 functions.udaf 注册)

    package com.home.spark
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    
    
    object Ex_sparkUDAF {
      /**
       * Entry point: registers [[MyAgeAvgFunc]] under the SQL name `ageAvg`
       * and applies it to the `age` column of `input/userinfo.json`.
       */
      def main(args: Array[String]): Unit = {
        val sparkConf = new SparkConf(true)
          .setAppName("spark udf")
          .setMaster("local[*]")
        val session = SparkSession.builder().config(sparkConf).getOrCreate()

        // Register the untyped aggregate so SQL text can call it as ageAvg(...).
        session.udf.register("ageAvg", new MyAgeAvgFunc)

        // Expose the JSON input as a temp view, then aggregate it through SQL.
        val users: DataFrame = session.read.json("input/userinfo.json")
        users.createOrReplaceTempView("userinfo")
        session.sql("select ageAvg(age) from userinfo").show()

        session.stop()
      }
    }
    
    /**
     * Untyped aggregate function (UDAF) that averages a single `age` column.
     *
     * Input rows look like `{ "name": "tom", "age" : 20}`; only `age` is passed in.
     * Null ages are skipped. When no non-null age has been aggregated, the result is
     * SQL NULL (like the built-in `avg`) rather than the NaN that 0 / 0 would produce.
     */
    class MyAgeAvgFunc extends UserDefinedAggregateFunction {
      /** Input schema: a single LongType column named "age". */
      override def inputSchema: StructType =
        new StructType().add("age", LongType)

      /**
       * Aggregation buffer schema. An average needs two running values:
       * the sum of ages (slot 0) and the number of ages counted (slot 1).
       */
      override def bufferSchema: StructType =
        new StructType().add("sum", LongType).add("count", LongType)

      /** Type of the value returned by `evaluate`. */
      override def dataType: DataType = DoubleType

      /** The same input always yields the same output. */
      override def deterministic: Boolean = true

      /** Resets the buffer before aggregation; slot order follows `bufferSchema`. */
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0L // sum
        buffer(1) = 0L // count
      }

      /**
       * Folds one input row (shaped like `inputSchema`, i.e. just the age) into the buffer.
       * Null ages are ignored so they affect neither the sum nor the count.
       */
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
        if (!input.isNullAt(0)) {
          buffer(0) = buffer.getLong(0) + input.getLong(0) // sum
          buffer(1) = buffer.getLong(1) + 1                // count: one more age seen
        }

      /** Combines two partial buffers; Spark aggregates partitions independently. */
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // sum
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) // count
      }

      /** Final result: sum / count, or null when no non-null age was aggregated. */
      override def evaluate(buffer: Row): Any = {
        val count = buffer.getLong(1)
        // Avoid returning NaN (0.0 / 0) for empty or all-null groups.
        if (count == 0L) null
        else buffer.getLong(0).toDouble / count
      }
    }

    2、使用强类型:继承 Aggregator

    package com.home.spark
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql._
    import org.apache.spark.sql.expressions.Aggregator
    
    
    object Ex_sparkUDAF2 {
      /**
       * Entry point: reads `input/userinfo.json` as a `Dataset[UserBean]` and selects
       * the average age through the typed aggregator [[MyAgeAvgClassFunc]].
       */
      def main(args: Array[String]): Unit = {
        val sparkConf = new SparkConf(true)
          .setAppName("spark udf class")
          .setMaster("local[*]")
        val session = SparkSession.builder().config(sparkConf).getOrCreate()

        // Converting a DataFrame to a Dataset needs the implicit encoders of the
        // SparkSession instance (`session` is the object, not a package name).
        import session.implicits._

        // Expose the typed aggregator as a column named "avgAge".
        val avgCol: TypedColumn[UserBean, Double] = new MyAgeAvgClassFunc().toColumn.name("avgAge")
        val users: Dataset[UserBean] = session.read.json("input/userinfo.json").as[UserBean]
        users.select(avgCol).show()

        session.stop()
      }
    }
    
    
    // Input record for the typed aggregator. `frame.as[UserBean]` maps DataFrame
    // columns to these fields by name, so the names must match the JSON keys.
    // `age` may be null (MyAgeAvgClassFunc.reduce checks for it).
    case class UserBean(name: String, age: BigInt)
    
    // Aggregation buffer: running sum of ages and number of ages counted.
    // Fields are `var` because MyAgeAvgClassFunc updates the buffer in place.
    // NOTE(review): `count` is an Int and would overflow past Int.MaxValue rows.
    case class AvgBuffer(var sum: BigInt, var count: Int)
    
    /**
     * Strongly-typed aggregate function computing the average age of [[UserBean]] records.
     *
     * Extends `Aggregator[IN = UserBean, BUF = AvgBuffer, OUT = Double]` and implements
     * its abstract methods. Users whose `age` is null are skipped. If no non-null age
     * was seen, `finish` yields NaN (0 / 0), because the output encoder is a
     * non-nullable `scala.Double`.
     */
    class MyAgeAvgClassFunc extends Aggregator[UserBean, AvgBuffer, Double] {
      /** Initial (empty) buffer: zero sum, zero count. */
      override def zero: AvgBuffer = AvgBuffer(0, 0)

      /**
       * Folds one user into the buffer. The buffer is mutated in place and returned,
       * which Aggregator allows to avoid allocating a new buffer per row.
       * NOTE(review): `AvgBuffer.count` is an Int, so more than Int.MaxValue rows would overflow.
       */
      override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
        // A missing age must not change the sum or the count.
        if (a.age != null) {
          b.sum = b.sum + a.age
          b.count = b.count + 1
        }
        b
      }

      /** Combines two partial buffers (from different partitions) into `b1`. */
      override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
        b1.sum = b1.sum + b2.sum
        b1.count = b1.count + b2.count
        b1
      }

      /** Final result: sum / count (NaN when count is 0). */
      override def finish(reduction: AvgBuffer): Double =
        reduction.sum.toDouble / reduction.count

      /** Encoder for the case-class buffer. */
      override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product

      /** Encoder for the Double result. */
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }

    继承Aggregator

  • 相关阅读:
    hdu1238 Substrings
    CCF试题:高速公路(Targin)
    hdu 1269 迷宫城堡(Targin算法)
    hdu 1253 胜利大逃亡
    NYOJ 55 懒省事的小明
    HDU 1024 Max Sum Plus Plus
    HDU 1087 Super Jumping! Jumping! Jumping!
    HDU 1257 最少拦截系统
    HDU 1069 Monkey and Banana
    HDU 1104 Remainder
  • 原文地址:https://www.cnblogs.com/asker009/p/12092684.html
Copyright © 2011-2022 走看看