  • Using UDAFs in Spark SQL

    1. Create a class that extends UserDefinedAggregateFunction.

    ---------------------------------------------------------------------

    package cn.piesat.test

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types.{DataType, DataTypes, IntegerType, StructType}

    class CountUDAF extends UserDefinedAggregateFunction {

      /**
       * Input type of the aggregate function.
       */
      override def inputSchema: StructType = {
        new StructType().add("ageType", IntegerType)
      }

      /**
       * Data type of the aggregation buffer.
       */
      override def bufferSchema: StructType = {
        new StructType().add("bufferAgeType", IntegerType)
      }

      /**
       * Return type of the UDAF.
       */
      override def dataType: DataType = {
        DataTypes.StringType
      }

      /**
       * Returns true if the function is deterministic (the same input
       * always produces the same output); true is almost always correct.
       */
      override def deterministic: Boolean = true

      /**
       * Initializes the aggregation buffer for each group.
       */
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        buffer(0) = 0
      }

      /**
       * Update step: defines how the group's aggregate value is computed
       * when a new input value arrives for that group.
       */
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        val num = input.getAs[Int](0)
        buffer(0) = buffer.getAs[Int](0) + num
      }

      /**
       * Merge step, executed when partial buffers from different
       * partitions are combined.
       */
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
      }

      /**
       * Computes the final result returned for the group.
       */
      override def evaluate(buffer: Row): Any = {
        buffer.getAs[Int](0).toString
      }
    }
    --------------------------------------------------------------
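
    Note: since Spark 3.0, UserDefinedAggregateFunction is deprecated in favor of the typed Aggregator API registered through functions.udaf. Below is a minimal sketch of the same sum-as-string aggregation written that way; the CountAgg name is illustrative and not part of the original code.

    ---------------------------------------------------------------------

    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.udaf
    import org.apache.spark.sql.{Encoder, Encoders}

    // Sums Int inputs and returns the total as a String, mirroring CountUDAF.
    object CountAgg extends Aggregator[Int, Int, String] {
      override def zero: Int = 0                             // initial buffer value
      override def reduce(buf: Int, in: Int): Int = buf + in // fold one input into the buffer
      override def merge(b1: Int, b2: Int): Int = b1 + b2    // combine partial buffers
      override def finish(buf: Int): String = buf.toString   // final result
      override def bufferEncoder: Encoder[Int] = Encoders.scalaInt
      override def outputEncoder: Encoder[String] = Encoders.STRING
    }

    // On Spark 3.x the registration becomes:
    // spark.udf.register("countUDF", udaf(CountAgg))

    ---------------------------------------------------------------------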


    2. Example usage in a main function
    ---------------------------------------------------------------
    package cn.piesat.test

    import org.apache.spark.sql.SparkSession

    import scala.collection.mutable.ArrayBuffer


    object SparkSQLTest {

      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("sparkSql").master("local[4]")
          .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
          .getOrCreate()
        val sc = spark.sparkContext
        // Parse each comma-separated line into a Worker; splited(1) is the age column.
        val workerRDD = sc.textFile("F://Workers.txt").mapPartitions(iter => {
          val array = new ArrayBuffer[Worker]()
          while (iter.hasNext) {
            val splited = iter.next().split(",")
            array.append(new Worker(splited(0), splited(1).toInt, splited(2)))
          }
          array.toIterator
        })
        import spark.implicits._
        // Register the UDAF under the name used in the SQL below.
        spark.udf.register("countUDF", new CountUDAF())
        val workDS = workerRDD.toDS()
        workDS.createOrReplaceTempView("worker")
        val resultDF = spark.sql("select countUDF(age) from worker")
        val resultDS = resultDF.as("WO")
        resultDS.show()

        spark.stop()
      }
    }
    -----------------------------------------------------------------------------------------------
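
    The Worker class referenced above is not shown in the post. For workerRDD.toDS() to find an encoder it needs to be a case class, presumably along these lines (the field names here are assumptions):

    ---------------------------------------------------------------------
    case class Worker(name: String, age: Int, dept: String)
    ---------------------------------------------------------------------

    With an input file laid out as name,age,dept per line, for example:

    zhangsan,25,dev
    lisi,30,ops

    the query would print a single row containing "55", the string-typed sum of the two ages.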
  • Original article: https://www.cnblogs.com/runnerjack/p/10662338.html