Table Aggregate Functions
A user-defined table aggregate function (UDTAGG) aggregates a table (one or more rows with one or more columns) into a result table that may again contain multiple rows and columns.

As an example, assume you have a table that contains data about beverages. The table consists of three columns, id, name, and price, and of 5 rows. Now you need to find the two highest-priced beverages in the table, i.e., perform a top2() table aggregation. You would need to go over all 5 rows, and the result would be a table with the top 2 values.
A user-defined table aggregate function is implemented by extending the TableAggregateFunction class. A TableAggregateFunction works as follows. First, it needs an accumulator, which is the data structure that holds the intermediate results of the aggregation. An empty accumulator is created by calling the createAccumulator() method of the TableAggregateFunction. Subsequently, the accumulate() method of the function is called for each input row to update the accumulator. Once all rows have been processed, the emitValue() method of the function is called to compute and return the final result.
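Put together as a minimal skeleton (a sketch only; the types MyAcc and MyResult and the class name MySketchFunction are placeholders introduced here for illustration, not part of Flink's API):

import org.apache.flink.table.functions.TableAggregateFunction
import org.apache.flink.util.Collector

// Placeholder accumulator and result types, used only to illustrate the lifecycle.
class MyAcc { var values: List[Int] = Nil }
case class MyResult(value: Int)

class MySketchFunction extends TableAggregateFunction[MyResult, MyAcc] {

  // 1. called to create an empty accumulator before any row is processed
  override def createAccumulator(): MyAcc = new MyAcc

  // 2. called once per input row to fold the row into the accumulator
  def accumulate(acc: MyAcc, v: Int): Unit = {
    acc.values = v :: acc.values
  }

  // 3. called when a result should be materialized; may emit zero, one, or many rows
  def emitValue(acc: MyAcc, out: Collector[MyResult]): Unit = {
    acc.values.foreach(v => out.collect(MyResult(v)))
  }
}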
The following methods of TableAggregateFunction are mandatory:

- createAccumulator()
- accumulate()
Flink's type extraction can derive wrong results for complex types, i.e., types that are neither basic types nor simple POJOs. Therefore, similar to ScalarFunction and TableFunction, TableAggregateFunction provides the TableAggregateFunction#getResultType() and TableAggregateFunction#getAccumulatorType() methods to specify the result type and the accumulator type; both methods return TypeInformation.
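A minimal sketch of such overrides (the MaxWithRank function and MaxAcc accumulator are hypothetical and only serve to show how the two methods can be implemented; the concrete TypeInformation to return depends on your own result and accumulator classes):

import java.lang.{Integer => JInteger}
import org.apache.flink.api.common.typeinfo.{TypeInformation, Types}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.TableAggregateFunction
import org.apache.flink.util.Collector

// Hypothetical accumulator that tracks the maximum value seen so far.
class MaxAcc { var max: JInteger = Int.MinValue }

// Emits the maximum value together with a constant rank of 1.
class MaxWithRank extends TableAggregateFunction[JTuple2[JInteger, JInteger], MaxAcc] {

  override def createAccumulator(): MaxAcc = new MaxAcc

  def accumulate(acc: MaxAcc, v: Int): Unit = {
    if (v > acc.max) acc.max = v
  }

  def emitValue(acc: MaxAcc, out: Collector[JTuple2[JInteger, JInteger]]): Unit = {
    out.collect(JTuple2.of(acc.max, 1))
  }

  // Explicit result type: rows are Java Tuple2 instances holding two integers.
  override def getResultType: TypeInformation[JTuple2[JInteger, JInteger]] =
    new TupleTypeInfo[JTuple2[JInteger, JInteger]](Types.INT, Types.INT)

  // Explicit accumulator type: plain extraction over the accumulator class.
  override def getAccumulatorType: TypeInformation[MaxAcc] =
    TypeInformation.of(classOf[MaxAcc])
}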
In addition to the methods above, there are a few contract methods that can be implemented optionally. While some of these methods allow the system to execute queries more efficiently, others are mandatory for certain use cases. For instance, the merge() method is mandatory if the aggregation function is used in the context of a session window (the two accumulators need to be merged when two session windows merge).

The following methods of TableAggregateFunction are required depending on the use case:

- retract() is required for aggregations on bounded OVER windows.
- merge() is required for many batch aggregations and session window aggregations.
- resetAccumulator() is required for many batch aggregations.
- emitValue() is required for batch aggregations and window aggregations.
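For instance, resetAccumulator() simply restores the accumulator to its initial state. A sketch of what it could look like if added to the Top2 function from the example further below (the method body is an assumption for illustration):

// Resets the accumulator so it can be reused for the next group (batch execution).
def resetAccumulator(acc: Top2Accum): Unit = {
  acc.first = Int.MinValue
  acc.second = Int.MinValue
}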
The following method of TableAggregateFunction can improve the performance of streaming jobs:

- emitUpdateWithRetract(): emits, in retract mode, only the values that have been updated.

The emitValue() method always emits the full result based on the accumulator. Taking TopN as an example, emitValue() emits all n top values every time it is called, which can cause performance problems in streaming jobs. To improve performance, a user can implement emitUpdateWithRetract(), which outputs data incrementally in retract mode: whenever there is an update, the old record is retracted before the new, updated record is emitted. If both methods are defined, emitUpdateWithRetract() is used in preference to emitValue(), because it is considered more efficient than emitValue(): it can output values incrementally.
All methods of a TableAggregateFunction must be declared public, must not be static, and must be named exactly as described above. The methods createAccumulator, getResultType, and getAccumulatorType are defined in the abstract parent class, whereas the other methods are contract methods. In order to define a table aggregate function, one has to extend the base class org.apache.flink.table.functions.TableAggregateFunction and implement one (or more) accumulate methods. The accumulate method can be overloaded with different parameter types and also supports variable-length arguments, as sketched below.
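A sketch of the overloading and varargs support (the FlexibleSum function and SumAcc accumulator are hypothetical, introduced only to illustrate the different accumulate signatures):

import java.lang.{Integer => JInteger}
import scala.annotation.varargs
import org.apache.flink.table.functions.TableAggregateFunction
import org.apache.flink.util.Collector

// Hypothetical accumulator that simply sums every value it sees.
class SumAcc { var sum: Long = 0L }

class FlexibleSum extends TableAggregateFunction[JInteger, SumAcc] {

  override def createAccumulator(): SumAcc = new SumAcc

  // overload for Int input
  def accumulate(acc: SumAcc, v: Int): Unit = {
    acc.sum += v
  }

  // overload for Long input
  def accumulate(acc: SumAcc, v: Long): Unit = {
    acc.sum += v
  }

  // variable-length arguments; @varargs additionally exposes a Java-style varargs signature
  @varargs
  def accumulate(acc: SumAcc, vs: String*): Unit = {
    vs.foreach(s => acc.sum += s.toLong)
  }

  def emitValue(acc: SumAcc, out: Collector[JInteger]): Unit = {
    out.collect(acc.sum.toInt)
  }
}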
Detailed documentation for all methods of TableAggregateFunction is given below.
/**
 * Base class for user-defined aggregates and table aggregates.
 *
 * @tparam T   the type of the aggregation result.
 * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
 *             aggregated values which are needed to compute an aggregation result.
 */
abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction {

  /**
   * Creates and initializes the Accumulator for this (table)aggregate function.
   *
   * @return the accumulator with the initial value
   */
  def createAccumulator(): ACC // MANDATORY

  /**
   * Returns the TypeInformation of the (table)aggregate function's result.
   *
   * @return The TypeInformation of the (table)aggregate function's result or null if the result
   *         type should be automatically inferred.
   */
  def getResultType: TypeInformation[T] = null // PRE-DEFINED

  /**
   * Returns the TypeInformation of the (table)aggregate function's accumulator.
   *
   * @return The TypeInformation of the (table)aggregate function's accumulator or null if the
   *         accumulator type should be automatically inferred.
   */
  def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED
}

/**
 * Base class for table aggregation functions.
 *
 * @tparam T   the type of the aggregation result
 * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the
 *             aggregated values which are needed to compute an aggregation result.
 *             TableAggregateFunction represents its state using accumulator, thereby the state of
 *             the TableAggregateFunction must be put into the accumulator.
 */
abstract class TableAggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] {

  /**
   * Processes the input values and updates the provided accumulator instance. The method
   * accumulate can be overloaded with different custom types and arguments. A TableAggregateFunction
   * requires at least one accumulate() method.
   *
   * @param accumulator           the accumulator which contains the current aggregated results
   * @param [user defined inputs] the input value (usually obtained from newly arrived data).
   */
  def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY

  /**
   * Retracts the input values from the accumulator instance. The current design assumes the
   * inputs are the values that have been previously accumulated. The method retract can be
   * overloaded with different custom types and arguments. This function must be implemented for
   * datastream bounded over aggregate.
   *
   * @param accumulator           the accumulator which contains the current aggregated results
   * @param [user defined inputs] the input value (usually obtained from newly arrived data).
   */
  def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL

  /**
   * Merges a group of accumulator instances into one accumulator instance. This function must be
   * implemented for datastream session window grouping aggregate and dataset grouping aggregate.
   *
   * @param accumulator the accumulator which will keep the merged aggregate results. It should
   *                    be noted that the accumulator may contain the previous aggregated
   *                    results. Therefore user should not replace or clean this instance in the
   *                    custom merge method.
   * @param its         an [[java.lang.Iterable]] pointing to a group of accumulators that will be
   *                    merged.
   */
  def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL

  /**
   * Called every time when an aggregation result should be materialized. The returned value
   * could be either an early and incomplete result (periodically emitted as data arrive) or
   * the final result of the aggregation.
   *
   * @param accumulator the accumulator which contains the current aggregated results
   * @param out         the collector used to output data
   */
  def emitValue(accumulator: ACC, out: Collector[T]): Unit // OPTIONAL

  /**
   * Called every time when an aggregation result should be materialized. The returned value
   * could be either an early and incomplete result (periodically emitted as data arrive) or
   * the final result of the aggregation.
   *
   * Different from emitValue, emitUpdateWithRetract is used to emit values that have been updated.
   * This method outputs data incrementally in retract mode, i.e., once there is an update, we
   * have to retract old records before sending new updated ones. The emitUpdateWithRetract
   * method will be used in preference to the emitValue method if both methods are defined in the
   * table aggregate function, because the method is treated to be more efficient than emitValue
   * as it can output values incrementally.
   *
   * @param accumulator the accumulator which contains the current aggregated results
   * @param out         the retractable collector used to output data. Use the collect method
   *                    to output (add) records and use the retract method to retract (delete)
   *                    records.
   */
  def emitUpdateWithRetract(accumulator: ACC, out: RetractableCollector[T]): Unit // OPTIONAL

  /**
   * Collects a record and forwards it. The collector can output retract messages with the retract
   * method. Note: only use it in `emitUpdateWithRetract`.
   */
  trait RetractableCollector[T] extends Collector[T] {

    /**
     * Retract a record.
     *
     * @param record The record to retract.
     */
    def retract(record: T): Unit
  }
}
The following example shows how to

- define a TableAggregateFunction that computes the top 2 values of a given column,
- register the function in the TableEnvironment, and
- use the function in a Table API query (TableAggregateFunction is currently only supported by the Table API).
In order to compute the top 2 values, the accumulator needs to store the two largest values seen so far. In our example the class Top2Accum serves as the accumulator. Accumulators are automatically backed up by Flink's checkpointing mechanism and restored in case of a failure to ensure exactly-once semantics.

The accumulate() method of our Top2 table aggregate function (TableAggregateFunction) has two inputs: the first one is the Top2Accum accumulator, the second one is the user-defined input, the value v. Although the merge() method is not mandatory for most aggregation types, we provide it in the example as well. Also note that the Scala example uses Java wrapper types because Flink's type extraction does not work well for Scala types; for the same reason you may need to override the getResultType() and getAccumulatorType() methods described above if the automatically extracted types are not correct.
import java.lang.{Integer => JInteger, Iterable => JIterable}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.TableAggregateFunction
import org.apache.flink.util.Collector

/**
 * Accumulator for top2.
 */
class Top2Accum {
  var first: JInteger = _
  var second: JInteger = _
}

/**
 * The top2 user-defined table aggregate function.
 */
class Top2 extends TableAggregateFunction[JTuple2[JInteger, JInteger], Top2Accum] {

  override def createAccumulator(): Top2Accum = {
    val acc = new Top2Accum
    acc.first = Int.MinValue
    acc.second = Int.MinValue
    acc
  }

  def accumulate(acc: Top2Accum, v: Int) {
    if (v > acc.first) {
      acc.second = acc.first
      acc.first = v
    } else if (v > acc.second) {
      acc.second = v
    }
  }

  def merge(acc: Top2Accum, its: JIterable[Top2Accum]): Unit = {
    val iter = its.iterator()
    while (iter.hasNext) {
      val top2 = iter.next()
      accumulate(acc, top2.first)
      accumulate(acc, top2.second)
    }
  }

  def emitValue(acc: Top2Accum, out: Collector[JTuple2[JInteger, JInteger]]): Unit = {
    // emit the value and rank
    if (acc.first != Int.MinValue) {
      out.collect(JTuple2.of(acc.first, 1))
    }
    if (acc.second != Int.MinValue) {
      out.collect(JTuple2.of(acc.second, 2))
    }
  }
}

// instantiate the function
val top2 = new Top2

// initialize the table
val tab = ...

// use the function
tab
  .groupBy('key)
  .flatAggregate(top2('a) as ('v, 'rank))
  .select('key, 'v, 'rank)
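The query above calls the function object top2 directly. As listed at the beginning of this example, the function can also be registered in the TableEnvironment so that it is addressable by name. A rough sketch, assuming a Scala StreamTableEnvironment named tEnv and the pre-1.11 registerFunction API (both assumptions; adapt to your setup and Flink version):

// assumed import for the Scala Table API and its implicit type information
import org.apache.flink.table.api.scala._

// register the table aggregate function under the name "top2"
tEnv.registerFunction("top2", new Top2)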
The following example shows how to use the emitUpdateWithRetract() method, which emits only the data that has been updated. In order to emit only updates, the accumulator keeps both the old top 2 values and the new top 2 values. Note: if the N of TopN is big, keeping both the old and the new values may be inefficient. One way to solve this is to store only the input records in the accumulator in accumulate() and to perform the actual computation when emitUpdateWithRetract() is called.
import java.lang.{Integer => JInteger}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.TableAggregateFunction
import org.apache.flink.table.functions.TableAggregateFunction.RetractableCollector

/**
 * Accumulator for top2.
 */
class Top2Accum {
  var first: JInteger = _
  var second: JInteger = _
  var oldFirst: JInteger = _
  var oldSecond: JInteger = _
}

/**
 * The top2 user-defined table aggregate function.
 */
class Top2 extends TableAggregateFunction[JTuple2[JInteger, JInteger], Top2Accum] {

  override def createAccumulator(): Top2Accum = {
    val acc = new Top2Accum
    acc.first = Int.MinValue
    acc.second = Int.MinValue
    acc.oldFirst = Int.MinValue
    acc.oldSecond = Int.MinValue
    acc
  }

  def accumulate(acc: Top2Accum, v: Int) {
    if (v > acc.first) {
      acc.second = acc.first
      acc.first = v
    } else if (v > acc.second) {
      acc.second = v
    }
  }

  def emitUpdateWithRetract(
      acc: Top2Accum,
      out: RetractableCollector[JTuple2[JInteger, JInteger]]): Unit = {
    if (acc.first != acc.oldFirst) {
      // if there is an update, retract the old value, then emit the new value.
      if (acc.oldFirst != Int.MinValue) {
        out.retract(JTuple2.of(acc.oldFirst, 1))
      }
      out.collect(JTuple2.of(acc.first, 1))
      acc.oldFirst = acc.first
    }
    if (acc.second != acc.oldSecond) {
      // if there is an update, retract the old value, then emit the new value.
      if (acc.oldSecond != Int.MinValue) {
        out.retract(JTuple2.of(acc.oldSecond, 2))
      }
      out.collect(JTuple2.of(acc.second, 2))
      acc.oldSecond = acc.second
    }
  }
}

// instantiate the function
val top2 = new Top2

// initialize the table
val tab = ...

// use the function
tab
  .groupBy('key)
  .flatAggregate(top2('a) as ('v, 'rank))
  .select('key, 'v, 'rank)