zoukankan      html  css  js  c++  java
  • sparksql 自定义用户函数(UDF)

    自定义用户函数有两种方式,区别:是否使用强类型,参考demo:https://github.com/asker124143222/spark-demo

    1、不使用强类型,继承UserDefinedAggregateFunction

    package com.home.spark
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.{DataFrame, Row, SparkSession}
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    
    
    object Ex_sparkUDAF {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(true).setAppName("spark udf").setMaster("local[*]")
        val spark = SparkSession.builder().config(conf).getOrCreate()
    
    
        //自定义聚合函数
        //创建聚合函数对象
        val myUdaf = new MyAgeAvgFunc
    
        //注册自定义函数
        spark.udf.register("ageAvg",myUdaf)
    
        //使用聚合函数
        val frame: DataFrame = spark.read.json("input/userinfo.json")
        frame.createOrReplaceTempView("userinfo")
        spark.sql("select ageAvg(age) from userinfo").show()
    
        spark.stop()
      }
    }
    
    //声明自定义函数
    //实现对年龄的平均,数据如:{ "name": "tom", "age" : 20}
    class MyAgeAvgFunc extends UserDefinedAggregateFunction {
      //函数输入的数据结构,本例中只有年龄是输入数据
      override def inputSchema: StructType = {
        new StructType().add("age", LongType)
      }
    
      //计算时的数据结构(缓冲区)
      // 本例中有要计算年龄平均值,必须有两个计算结构,一个是年龄总计(sum),一个是年龄个数(count)
      override def bufferSchema: StructType = {
        new StructType().add("sum", LongType).add("count", LongType)
      }
    
      //函数返回的数据类型
      override def dataType: DataType = DoubleType
    
      //函数是否稳定
      override def deterministic: Boolean = true
    
      //计算前缓冲区的初始化,结构类似数组,这里缓冲区与之前定义的bufferSchema顺序一致
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        //sum
        buffer(0) = 0L
        //count
        buffer(1) = 0L
      }
    
      //根据查询结果更新缓冲区数据,input是每次进入的数据,其数据结构与之前定义的inputSchema相同
      //本例中每次输入的数据只有一个就是年龄
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if(input.isNullAt(0)) return
        //sum
        buffer(0) = buffer.getLong(0) + input.getLong(0)
    
        //count,每次来一个数据加1
        buffer(1) = buffer.getLong(1) + 1
      }
    
      //将多个节点的缓冲区合并到一起(因为spark是分布式的)
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        //sum
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    
        //count
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
      }
    
      //计算最终结果,本例中就是(sum / count)
      override def evaluate(buffer: Row): Any = {
        buffer.getLong(0).toDouble / buffer.getLong(1)
      }
    }

    2、使用强类型,

    package com.home.spark
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql._
    import org.apache.spark.sql.expressions.Aggregator
    
    
    object Ex_sparkUDAF2 {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(true).setAppName("spark udf class").setMaster("local[*]")
        val spark = SparkSession.builder().config(conf).getOrCreate()
    
        //rdd转换成df或者ds需要SparkSession实例的隐式转换
        //导入隐式转换,注意这里的spark不是包名,而是SparkSession的对象名
        import spark.implicits._
    
        //创建聚合函数对象
        val myAvgFunc = new MyAgeAvgClassFunc
        val avgCol: TypedColumn[UserBean, Double] = myAvgFunc.toColumn.name("avgAge")
        val frame = spark.read.json("input/userinfo.json")
        val userDS: Dataset[UserBean] = frame.as[UserBean]
        //应用函数
        userDS.select(avgCol).show()
    
        spark.stop()
      }
    }
    
    
    case class UserBean(name: String, age: BigInt)
    
    case class AvgBuffer(var sum: BigInt, var count: Int)
    
    //声明用户自定义函数(强类型方式)
    //继承Aggregator,设定泛型
    //实现方法
    class MyAgeAvgClassFunc extends Aggregator[UserBean, AvgBuffer, Double] {
      //初始化缓冲区
      override def zero: AvgBuffer = {
        AvgBuffer(0, 0)
      }
    
      //聚合数据
      override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
        if(a.age == null) return b
        b.sum = b.sum + a.age
        b.count = b.count + 1
    
        b
      }
    
      //缓冲区合并操作
      override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
        b1.sum = b1.sum + b2.sum
        b1.count = b1.count + b2.count
    
        b1
      }
    
      //完成计算
      override def finish(reduction: AvgBuffer): Double = {
        reduction.sum.toDouble / reduction.count
      }
    
      override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
    
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }

    继承Aggregator

  • 相关阅读:
    Eureka 系列(04)客户端源码分析
    Eureka 系列(03)Spring Cloud 自动装配原理
    Eureka 系列(02)Eureka 一致性协议
    Eureka 系列(01)最简使用姿态
    Feign 系列(05)Spring Cloud OpenFeign 源码解析
    python 线程,进程与协程
    Python IO多路复用
    python 作用域
    python 网络编程:socket(二)
    python 网络编程:socket
  • 原文地址:https://www.cnblogs.com/asker009/p/12092684.html
Copyright © 2011-2022 走看看