zoukankan      html  css  js  c++  java
  • Spark SQL UDF 函数(四)

    Spark 中也支持 Hive 中的自定义函数。自定义函数大致可以分为三种:

    • UDF(User-Defined-Function):即最基本的自定义函数,类似 to_char,to_date
    • UDAF(User-Defined Aggregate Function):用户自定义聚合函数,类似在 group by 之后使用的 sum、avg
    • UDTF(User-Defined Table-Generating Function):用户自定义表生成函数,一行输入可以产生多行输出,有点像 Stream 里面的 flatMap

    1. 初步使用 UDF 函数

    
    scala> val df = spark.read.json("hdfs://hadoop1:9000/people.json")
    df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]
    
    // 注册使用,toUpper 为函数名称
    scala> spark.udf.register("toUpper", (s: String) => s.toUpperCase)
    res15: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
    
    scala> df.createOrReplaceTempView("people")
    
    scala> spark.sql("select toUpper(name), age from people").show
    +-----------------+----+
    |UDF:toUpper(name)| age|
    +-----------------+----+
    |          MICHAEL|null|
    |             ANDY|  30|
    |           JUSTIN|  19|
    +-----------------+----+
    

    2. 自定义UDAF 聚合函数

    package top.midworld.spark1031.create_df
    
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
    import org.apache.spark.sql.{Row, SparkSession}
    
    /**
     * One parsed input record: a person's name and age.
     *
     * The age is kept as a Double so the column can be passed straight to the
     * Double-typed UDAFs defined below.
     */
    case class UserInfo(
      name: String,
      age: Double
    )
    
    /** Demo driver: loads `name, age` text lines into a DataFrame and aggregates them with the [[MySum]] UDAF via SQL. */
    object UDF1 {
      /**
       * Entry point.
       *
       * @param args optional; `args(0)` overrides the input path, which defaults to
       *             `hdfs://hadoop1:9000/people.txt`. Each non-blank line is expected
       *             to have the form `name, age`.
       */
      def main(args: Array[String]): Unit = {
        val inputPath = args.headOption.getOrElse("hdfs://hadoop1:9000/people.txt")
        val spark = SparkSession.builder.appName("udf").master("local[2]").getOrCreate()

        // Always release the session, even if reading or querying fails.
        // spark.stop() also stops the underlying SparkContext, so no separate sc.stop() is needed.
        try {
          val sc = spark.sparkContext
          import spark.implicits._

          val df = sc.textFile(inputPath)
            .filter(_.trim.nonEmpty) // a blank line would split to Array("") and make x(1) throw
            .map(_.split(","))
            .map(x => UserInfo(x(0), x(1).trim.toDouble))
            .toDF()
          df.createOrReplaceTempView("user")

          // Register the UDAF under the SQL function name "mySum".
          spark.udf.register("mySum", new MySum)

          spark.sql("select mySum(age) as age_sum from user").show()

          df.show()
        } finally {
          spark.stop()
        }
      }
    }
    
    /**
     * UDAF that sums a single Double column.
     *
     * The `println` calls deliberately trace when Spark invokes each callback
     * (initialize / update / merge); they are part of this tutorial's demo.
     */
    class MySum extends UserDefinedAggregateFunction {
      /** One input column, "ele", of type Double (e.g. the ages 29 / 30 / 19). */
      override def inputSchema: StructType =
        StructType(Seq(StructField("ele", DoubleType)))

      /** Intermediate state: a single running total named "sum". */
      override def bufferSchema: StructType =
        StructType(Seq(StructField("sum", DoubleType)))

      /** Type of the final aggregated value. */
      override def dataType: DataType = DoubleType

      /** The same input always produces the same output. */
      override def deterministic: Boolean = true

      /** Resets the running total to zero; the incoming buffer prints as `[null]`. */
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        println("initialize===>" + buffer)
        buffer.update(0, 0.0)
      }

      /**
       * Within a partition: adds one input value to the running total.
       * Inputs that do not match a single Double (e.g. null) leave the total unchanged.
       *
       * Sample output recorded in the original run:
       *   update===>[0.0], update===>[0.0], input===>[19.0], input===>[29.0],
       *   update===>[29.0], input===>[30.0]
       */
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        println("update===>" + buffer)
        println("input===>" + input)

        Some(input)
          .collect { case Row(age: Double) => age }
          .foreach(age => buffer.update(0, buffer.getDouble(0) + age))
      }

      /**
       * Across partitions: folds the partial total in `buffer2` into `buffer1`.
       *
       * Sample output recorded in the original run:
       *   merge buffer1 ===>[0.0],  merge buffer2 ===>[59.0],
       *   merge buffer1 ===>[59.0], merge buffer2 ===>[19.0]
       */
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        println("merge buffer1 ===>" + buffer1)
        println("merge buffer2 ===>" + buffer2)

        val combined = buffer1.getDouble(0) + buffer2.getDouble(0)
        buffer1.update(0, combined)
      }

      /** Produces the final result: the accumulated total. */
      override def evaluate(buffer: Row): Any = {
        val total = buffer.getDouble(0)
        total
      }
    }
    

    运行结果:

    +-------+
    |age_sum|
    +-------+
    |   78.0|
    +-------+
    
    +-------+----+
    |   name| age|
    +-------+----+
    |Michael|29.0|
    |   Andy|30.0|
    | Justin|19.0|
    +-------+----+
    

    自定义 UDAF:求平均值(缓冲区同时保存 sum 与 count,需要导入 LongType)

    /**
     * UDAF that averages a single Double column.
     *
     * Keeps a (sum, count) buffer and divides at the end. Null inputs are skipped.
     * Returns null when no non-null value was aggregated (e.g. an empty table),
     * matching SQL's built-in `avg`, instead of NaN from 0.0 / 0.
     */
    class MyAvg extends UserDefinedAggregateFunction {
      /** One input column, "ele", of type Double (e.g. the ages 29 / 30 / 19). */
      override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)

      /** Intermediate state: running "sum" (Double) and number of values seen, "count" (Long). */
      override def bufferSchema: StructType =
        StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)

      /** Type of the final aggregated value. */
      override def dataType: DataType = DoubleType

      /** The same input always produces the same output. */
      override def deterministic: Boolean = true

      /** Starts each buffer at sum = 0.0, count = 0. */
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0D
        buffer(1) = 0L
      }

      /** Within a partition: adds one value to the sum and bumps the count; non-Double/null inputs are ignored. */
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        input match {
          case Row(age: Double) =>
            buffer(0) = buffer.getDouble(0) + age
            buffer(1) = buffer.getLong(1) + 1L
          case _ => // null input: contributes to neither sum nor count
        }
      }

      /** Across partitions: adds `buffer2`'s sum and count into `buffer1`. */
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
        // Accumulate into buffer1's own count. Adding buffer2's count to itself
        // would drop buffer1's rows and double-count buffer2's.
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
      }

      /** Returns sum / count, or null when nothing was aggregated. */
      override def evaluate(buffer: Row): Any = {
        val count = buffer.getLong(1)
        if (count == 0L) null else buffer.getDouble(0) / count
      }
    }
    

    3. 开窗函数

    https://blog.csdn.net/sunxiaoju/article/details/103800028

    https://blog.csdn.net/liangzelei/article/details/80608302?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-4.no_search_link&spm=1001.2101.3001.4242.3

  • 相关阅读:
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    Live2D 看板娘
    [转载]MySQL5.5 配置文件 my.ini 1067错误
  • 原文地址:https://www.cnblogs.com/midworld/p/15647008.html
Copyright © 2011-2022 走看看