zoukankan      html  css  js  c++  java
  • UDAF(用户自定义聚合函数)求众数

    除了逐行处理数据的udf,还有比较常见的就是聚合多行处理udaf,自定义聚合函数。类比rdd编程就是map和reduce算子的区别。
    自定义UDAF,需要extends org.apache.spark.sql.expressions.UserDefinedAggregateFunction,并实现接口中的8个方法。
    udaf写起来比较麻烦,我下面列一个之前写的取众数聚合函数,在我们通常在聚合统计的时候可能会受某条脏数据的影响。
    举个栗子:
    对于一个app日志聚合的时候,有id与ip,原则上一个id有一个ip,但是在多条数据里有一条ip是错误的或者为空的,这时候group能会聚合成两条数据了就,如果使用max,min对ip也进行聚合,那也不太合理,这时候可以进行投票,去类似多数对结果,从而聚合后只有一个设备。
    废话少说,上代码:
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    
    /**
      * Description: 自定义聚合函数:众数(取列内频率最高的一条)
      */
    
    class UDAFGetMode extends UserDefinedAggregateFunction {
      override def inputSchema: StructType = {
        StructType(StructField("inputStr", StringType, true) :: Nil)
      }
    
    
      override def bufferSchema: StructType = {
        StructType(StructField("bufferMap", MapType(keyType = StringType, valueType = IntegerType), true) :: Nil)
      }
    
      override def dataType: DataType = StringType
    
      override def deterministic: Boolean = false
    
      //初始化map
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = scala.collection.immutable.Map[String, Int]()
      }
    
      //如果包含这个key则value+1,否则写入key,value=1
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        val key = input.getAs[String](0)
        val immap = buffer.getAs[Map[String, Int]](0)
        val bufferMap = scala.collection.mutable.Map[String, Int](immap.toSeq: _*)
        val ret = if (bufferMap.contains(key)) {
          //      val new_value = bufferMap.get(key).get + 1
          val new_value = bufferMap(key) + 1
          bufferMap.put(key, new_value)
          bufferMap
        } else {
          bufferMap.put(key, 1)
          bufferMap
        }
        buffer.update(0, ret)
    
      }
    
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        //合并两个map 相同的key的value累加
        val tempMap = (buffer1.getAs[Map[String, Int]](0) /: buffer2.getAs[Map[String, Int]](0)) {
          case (map, (k, v)) => map + (k -> (v + map.getOrElse(k, 0)))
        }
        buffer1.update(0, tempMap)
      }
    
      override def evaluate(buffer: Row): Any = {
        //返回值最大的key
        var max_value = 0
        var max_key = ""
        buffer.getAs[Map[String, Int]](0).foreach({ x =>
          val key = x._1
          val value = x._2
          if (value > max_value) {
            max_value = value
            max_key = key
          }
        })
        max_key
      }
    }

    测试类:

    object UDAFTest {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local").appName(this.getClass.getSimpleName).getOrCreate()
        spark.udf.register("get_mode", new UDAFGetMode)
        import spark.implicits._
        val df = Seq(
          (1, "10.10.1.1", "start"),
          (1, "10.10.1.1", "search"),
          (2, "123.123.123.1", "search"),
          (1, "10.10.1.0", "stop"),
          (2, "123.123.123.1", "start")
        ).toDF("id", "ip", "action")
    
        df.createTempView("tb")
        spark.sql(s"select id,get_mode(ip) as u_ip,count(*) as cnt from tb group by id").show()
      }
    }
  • 相关阅读:
    Powershell 音乐播放
    Powershell指令集_2
    Powershell指令集_2
    Powershell 邮件发送
    Powershell 邮件发送
    Oracle 11g 关闭内存自动管理
    Oracle 11g 内存手动管理
    Oracle 内存参数调整
    RESTful三理解
    RESTful三理解
  • 原文地址:https://www.cnblogs.com/itboys/p/10626310.html
Copyright © 2011-2022 走看看