Spark SQL ships with a rich set of built-in functions, but real business logic can be complex and the built-ins do not always cover it, so Spark SQL also lets you register your own functions.
UDF: an ordinary row-level function that takes one or more arguments and returns a single value, e.g. length(), isnull().
UDAF: an aggregate function that takes a group of values and returns one aggregated result, e.g. max(), avg(), sum().
Writing a UDF in Spark
The first example uses the older SQLContext entry point (pre-Spark 2.0 style) and shows a UDF with a single input and a single output.
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF1 {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf();
        sparkConf.setMaster("local[2]");
        sparkConf.setAppName("spark udf test");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        @SuppressWarnings("deprecation")
        SQLContext sqlContext = new SQLContext(javaSparkContext);

        JavaRDD<String> javaRDD = javaSparkContext.parallelize(
                Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"));
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sqlContext.createDataFrame(rowRDD, schema);
        ds.createOrReplaceTempView("user");

        // Choose which UDF interface to implement (UDF1, UDF2, ... up to UDF22)
        // according to the number of input arguments.
        sqlContext.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        Dataset<Row> rows = sqlContext.sql("select id,name,strLength(name) as length from user");
        rows.show();

        javaSparkContext.stop();
    }
}
Output:
+---+--------+------+
| id| name|length|
+---+--------+------+
| 1|zhangsan| 8|
| 2| lisi| 4|
| 3| wangwu| 6|
| 4| zhaoliu| 7|
+---+--------+------+
The UDF above has a single input and a single output. The next example switches to the Spark 2.0 SparkSession API and implements a UDF with three inputs and one output (UDF3).
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF3;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder()
                .appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(
                Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"), Encoders.STRING());

        // Choose which UDF interface to implement (UDF1, UDF2, ... up to UDF22)
        // according to the number of input arguments.
        sparkSession.udf().register("strLength", new UDF1<String, Integer>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public Integer call(String t1) throws Exception {
                return t1.length();
            }
        }, DataTypes.IntegerType);

        sparkSession.udf().register("strConcat", new UDF3<String, String, String, String>() {
            private static final long serialVersionUID = -8172995965965931129L;

            @Override
            public String call(String combChar, String t1, String t2) throws Exception {
                return t1 + combChar + t2;
            }
        }, DataTypes.StringType);

        showByStruct(sparkSession, row);
        System.out.println("==========================================");
        showBySchema(sparkSession, row);

        sparkSession.stop();
    }

    private static void showBySchema(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                return RowFactory.create(fields);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql(
                "select id,name,strLength(name) as length,strConcat('+',id,name) as str from user");
        rows.show();
    }

    private static void showByStruct(SparkSession sparkSession, Dataset<String> row) {
        JavaRDD<Person> map = row.javaRDD().map(Person::parsePerson);
        Dataset<Row> persons = sparkSession.createDataFrame(map, Person.class);
        persons.show();
        persons.createOrReplaceTempView("user");

        Dataset<Row> rows = sparkSession.sql(
                "select id,name,strLength(name) as length,strConcat('-',id,name) as str from user");
        rows.show();
    }
}
Person.java
package com.dx.streaming.producer;

import java.io.Serializable;

public class Person implements Serializable {
    private String id;
    private String name;

    public Person(String id, String name) {
        this.id = id;
        this.name = name;
    }

    public String getId() { return id; }
    public void setId(String id) { this.id = id; }
    public String getName() { return name; }
    public void setName(String name) { this.name = name; }

    public static Person parsePerson(String line) {
        String[] fields = line.split(",");
        Person person = new Person(fields[0], fields[1]);
        return person;
    }
}
Note that a UDF registered on the session is global: register it once and it can be called any number of times (here both showByStruct and showBySchema reuse the same registrations).
Output:
+---+--------+
| id| name|
+---+--------+
| 1|zhangsan|
| 2| lisi|
| 3| wangwu|
| 4| zhaoliu|
+---+--------+
+---+--------+------+----------+
| id| name|length| str|
+---+--------+------+----------+
| 1|zhangsan| 8|1-zhangsan|
| 2| lisi| 4| 2-lisi|
| 3| wangwu| 6| 3-wangwu|
| 4| zhaoliu| 7| 4-zhaoliu|
+---+--------+------+----------+
==========================================
+---+--------+
| id| name|
+---+--------+
| 1|zhangsan|
| 2| lisi|
| 3| wangwu|
| 4| zhaoliu|
+---+--------+
+---+--------+------+----------+
| id| name|length| str|
+---+--------+------+----------+
| 1|zhangsan| 8|1+zhangsan|
| 2| lisi| 4| 2+lisi|
| 3| wangwu| 6| 3+wangwu|
| 4| zhaoliu| 7| 4+zhaoliu|
+---+--------+------+----------+
Between these two examples you have everything needed to write and register your own UDFs.
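As a side note, if Java 8 lambdas are available, the anonymous UDF1/UDF3 classes above can be written more compactly. A minimal sketch, assuming Java 8 and the same sparkSession and imports as in TestUDF2 above:

// Sketch only: UDF1/UDF3 each have a single abstract method, so lambdas with an
// explicit interface cast can replace the anonymous classes used above.
sparkSession.udf().register("strLength",
        (UDF1<String, Integer>) s -> s.length(), DataTypes.IntegerType);
sparkSession.udf().register("strConcat",
        (UDF3<String, String, String, String>) (sep, a, b) -> a + sep + b, DataTypes.StringType);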
Writing a UDAF in Spark
A custom aggregate function extends UserDefinedAggregateFunction. Here is the definition of that abstract class, annotated method by method:
package org.apache.spark.sql.expressions

import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.types._
import org.apache.spark.annotation.Experimental

/**
 * :: Experimental ::
 * The base class for implementing user-defined aggregate functions (UDAF).
 */
@Experimental
abstract class UserDefinedAggregateFunction extends Serializable {

  /**
   * A [[StructType]] represents data types of input arguments of this aggregate function.
   * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
   * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *     .add("doubleInput", DoubleType)
   *     .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * input argument. Users can choose names to identify the input arguments.
   */
  // Schema (data types) of the input arguments.
  def inputSchema: StructType

  /**
   * A [[StructType]] represents data types of values in the aggregation buffer.
   * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
   * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
   * the returned [[StructType]] will look like
   *
   * ```
   *   new StructType()
   *     .add("doubleInput", DoubleType)
   *     .add("longInput", LongType)
   * ```
   *
   * The name of a field of this [[StructType]] is only used to identify the corresponding
   * buffer value. Users can choose names to identify the input arguments.
   */
  // Schema of the intermediate values produced while aggregating.
  def bufferSchema: StructType

  /**
   * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
   */
  // Data type of the final aggregation result.
  def dataType: DataType

  /**
   * Returns true if this function is deterministic, i.e. given the same input,
   * always return the same output.
   */
  // Determinism flag: if true, the same input always produces the same output.
  def deterministic: Boolean

  /**
   * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
   *
   * The contract should be that applying the merge function on two initial buffers should just
   * return the initial buffer itself, i.e.
   * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
   */
  // Sets the initial buffer value. It must honor the contract above: merging two initial buffers
  // must yield an initial buffer again. For example, if the initial value were 1 and merge added
  // the buffers, merging two initial buffers would give 2, which is no longer the initial buffer;
  // such an initial value is wrong. That is why the initial value is also called the "zero value".
  def initialize(buffer: MutableAggregationBuffer): Unit

  /**
   * Updates the given aggregation buffer `buffer` with new input data from `input`.
   *
   * This is called once per input row.
   */
  // Updates the buffer with one input row, similar to combineByKey.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit

  /**
   * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
   *
   * This is called when we merge two partially aggregated data together.
   */
  // Merges buffer2 into buffer1 when partial aggregates from different partitions are combined,
  // similar to reduceByKey. Note that the method returns nothing: you must write the merged
  // values back into buffer1 yourself.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit

  /**
   * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
   * aggregation buffer.
   */
  // Computes and returns the final aggregation result.
  def evaluate(buffer: Row): Any

  /**
   * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
   */
  // Aggregates over all input values.
  @scala.annotation.varargs
  def apply(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = false)
    Column(aggregateExpression)
  }

  /**
   * Creates a [[Column]] for this UDAF using the distinct values of the given
   * [[Column]]s as input arguments.
   */
  // Aggregates over the distinct input values only.
  @scala.annotation.varargs
  def distinct(exprs: Column*): Column = {
    val aggregateExpression =
      AggregateExpression2(
        ScalaUDAF(exprs.map(_.expr), this),
        Complete,
        isDistinct = true)
    Column(aggregateExpression)
  }
}

/**
 * :: Experimental ::
 * A [[Row]] representing an mutable aggregation buffer.
 *
 * This is not meant to be extended outside of Spark.
 */
@Experimental
abstract class MutableAggregationBuffer extends Row {

  /** Update the ith value of this buffer. */
  def update(i: Int, value: Any): Unit
}
A UDAF that computes the average of a single column:
package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class SimpleAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;

    @Override
    public StructType inputSchema() {
        StructType structType = new StructType().add("myinput", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public StructType bufferSchema() {
        StructType structType = new StructType()
                .add("mycnt", DataTypes.LongType)
                .add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Sets the initial ("zero value") buffer. Merging two initial buffers must again yield an
    // initial buffer; with a count and a sum of 0 that contract holds.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // mycnt: a Long zero
        buffer.update(1, 0d); // mysum: a Double zero
    }

    // Combine within a partition: update the buffer with one input row (similar to combineByKey).
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0) + 1);                    // increment the row count
        buffer.update(1, buffer.getDouble(1) + input.getDouble(0)); // add the input value to the sum
    }

    // Merge across partitions: fold buffer2 into buffer1 (similar to reduceByKey).
    // There is no return value; the merged values must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));     // merge counts
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1)); // merge sums
    }

    // Compute and return the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));
        return avgFormat;
    }
}
Here is how to register and use the custom UDAF:
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder()
                .appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80", "2,zhangsan,History,87", "3,zhangsan,Chinese,88",
                "4,zhangsan,Chemistry,96", "5,lisi,English,70", "6,lisi,Chinese,74",
                "7,lisi,History,75", "8,lisi,Chemistry,77", "9,lisi,Physics,79",
                "10,lisi,Biology,82", "11,wangwu,English,96", "12,wangwu,Chinese,98",
                "13,wangwu,History,91", "14,zhaoliu,English,68", "15,zhaoliu,Chinese,66"),
                Encoders.STRING());

        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve = Double.parseDouble(fields[3]);
                return RowFactory.create(id, name, subject, achieve);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new SimpleAvg();
        sparkSession.udf().register("avg_format", udaf);

        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve) avg_achieve from user group by name");
        rows2.show();
    }
}
Output:
+---+--------+---------+-------+
| id| name| subject|achieve|
+---+--------+---------+-------+
| 1|zhangsan| English| 80.0|
| 2|zhangsan| History| 87.0|
| 3|zhangsan| Chinese| 88.0|
| 4|zhangsan|Chemistry| 96.0|
| 5| lisi| English| 70.0|
| 6| lisi| Chinese| 74.0|
| 7| lisi| History| 75.0|
| 8| lisi|Chemistry| 77.0|
| 9| lisi| Physics| 79.0|
| 10| lisi| Biology| 82.0|
| 11| wangwu| English| 96.0|
| 12| wangwu| Chinese| 98.0|
| 13| wangwu| History| 91.0|
| 14| zhaoliu| English| 68.0|
| 15| zhaoliu| Chinese| 66.0|
+---+--------+---------+-------+
+--------+-----------------+
| name| avg_achieve|
+--------+-----------------+
| wangwu| 95.0|
| zhaoliu| 67.0|
|zhangsan| 87.75|
| lisi|76.16666666666667|
+--------+-----------------+
+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 95.0|
| zhaoliu| 67.0|
|zhangsan| 87.75|
| lisi| 76.17|
+--------+-----------+
A UDAF that sums several columns per row and then averages those row sums:
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF1 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder()
                .appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89", "2,zhangsan,History,87,88", "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95", "5,lisi,English,70,75", "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80", "8,lisi,Chemistry,77,70", "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83", "11,wangwu,English,96,84", "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92", "14,zhaoliu,English,68,80", "15,zhaoliu,Chinese,66,69"),
                Encoders.STRING());

        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve1 = Double.parseDouble(fields[3]);
                Double achieve2 = Double.parseDouble(fields[4]);
                return RowFactory.create(id, name, subject, achieve1, achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new MutilAvg(2);
        sparkSession.udf().register("avg_format", udaf);

        Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve1+achieve2) avg_achieve from user group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve1,achieve2) avg_achieve from user group by name");
        rows2.show();
    }
}
The code above builds a Dataset with the columns id, name, subject, achieve1 and achieve2. MutilAvg sums achieve1 and achieve2 of each row and then averages those row sums within each group.
MutilAvg.java (the UDAF):
package com.dx.streaming.producer;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MutilAvg extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize = 1;

    public MutilAvg(int columnSize) {
        this.columnSize = columnSize;
    }

    @Override
    public StructType inputSchema() {
        // StructType.add returns a new StructType, so the result must be reassigned.
        StructType structType = new StructType();
        for (int i = 0; i < columnSize; i++) {
            structType = structType.add("myinput" + i, DataTypes.DoubleType);
        }
        return structType;
    }

    @Override
    public StructType bufferSchema() {
        StructType structType = new StructType()
                .add("mycnt", DataTypes.LongType)
                .add("mysum", DataTypes.DoubleType);
        return structType;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    // Sets the initial ("zero value") buffer. Merging two initial buffers must again yield an
    // initial buffer; with a count and a sum of 0 that contract holds.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L); // mycnt: a Long zero
        buffer.update(1, 0d); // mysum: a Double zero
    }

    // Combine within a partition: update the buffer with one input row (similar to combineByKey).
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        buffer.update(0, buffer.getLong(0) + 1); // increment the row count
        // One input row carries several columns, so sum the columns of the same row first.
        Double currentLineSumValue = 0d;
        for (int i = 0; i < columnSize; i++) {
            currentLineSumValue += input.getDouble(i);
        }
        buffer.update(1, buffer.getDouble(1) + currentLineSumValue); // add the row sum
    }

    // Merge across partitions: fold buffer2 into buffer1 (similar to reduceByKey).
    // There is no return value; the merged values must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0));     // merge counts
        buffer1.update(1, buffer1.getDouble(1) + buffer2.getDouble(1)); // merge sums
    }

    // Compute and return the final aggregation result.
    @Override
    public Object evaluate(Row buffer) {
        Double avg = buffer.getDouble(1) / buffer.getLong(0);
        Double avgFormat = Double.parseDouble(String.format("%.2f", avg));
        return avgFormat;
    }
}
Test output:
+---+--------+---------+--------+--------+
| id| name| subject|achieve1|achieve2|
+---+--------+---------+--------+--------+
| 1|zhangsan| English| 80.0| 89.0|
| 2|zhangsan| History| 87.0| 88.0|
| 3|zhangsan| Chinese| 88.0| 87.0|
| 4|zhangsan|Chemistry| 96.0| 95.0|
| 5| lisi| English| 70.0| 75.0|
| 6| lisi| Chinese| 74.0| 67.0|
| 7| lisi| History| 75.0| 80.0|
| 8| lisi|Chemistry| 77.0| 70.0|
| 9| lisi| Physics| 79.0| 80.0|
| 10| lisi| Biology| 82.0| 83.0|
| 11| wangwu| English| 96.0| 84.0|
| 12| wangwu| Chinese| 98.0| 64.0|
| 13| wangwu| History| 91.0| 92.0|
| 14| zhaoliu| English| 68.0| 80.0|
| 15| zhaoliu| Chinese| 66.0| 69.0|
+---+--------+---------+--------+--------+
+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 175.0|
| zhaoliu| 141.5|
|zhangsan| 177.5|
| lisi| 152.0|
+--------+-----------+
+--------+-----------+
| name|avg_achieve|
+--------+-----------+
| wangwu| 175.0|
| zhaoliu| 141.5|
|zhangsan| 177.5|
| lisi| 152.0|
+--------+-----------+
A UDAF that computes the maximum of each column separately and then returns the largest of those per-column maxima:
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TestUDAF2 {
    public static void main(String[] args) {
        SparkSession sparkSession = SparkSession.builder()
                .appName("spark udf test").master("local[2]").getOrCreate();
        Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                "1,zhangsan,English,80,89", "2,zhangsan,History,87,88", "3,zhangsan,Chinese,88,87",
                "4,zhangsan,Chemistry,96,95", "5,lisi,English,70,75", "6,lisi,Chinese,74,67",
                "7,lisi,History,75,80", "8,lisi,Chemistry,77,70", "9,lisi,Physics,79,80",
                "10,lisi,Biology,82,83", "11,wangwu,English,96,84", "12,wangwu,Chinese,98,64",
                "13,wangwu,History,91,92", "14,zhaoliu,English,68,80", "15,zhaoliu,Chinese,66,69"),
                Encoders.STRING());

        JavaRDD<String> javaRDD = row.javaRDD();
        JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
            private static final long serialVersionUID = -4769584490875182711L;

            @Override
            public Row call(String line) throws Exception {
                String[] fields = line.split(",");
                Integer id = Integer.parseInt(fields[0]);
                String name = fields[1];
                String subject = fields[2];
                Double achieve1 = Double.parseDouble(fields[3]);
                Double achieve2 = Double.parseDouble(fields[4]);
                return RowFactory.create(id, name, subject, achieve1, achieve2);
            }
        });

        List<StructField> fields = new ArrayList<StructField>();
        fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
        fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
        fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));
        StructType schema = DataTypes.createStructType(fields);

        Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
        ds.show();
        ds.createOrReplaceTempView("user");

        UserDefinedAggregateFunction udaf = new MutilMax(2, 0);
        sparkSession.udf().register("max_vals", udaf);

        Dataset<Row> rows1 = sparkSession.sql(""
                + "select name,max(achieve) as max_achieve "
                + "from "
                + "("
                + "select name,max(achieve1) achieve from user group by name "
                + "union all "
                + "select name,max(achieve2) achieve from user group by name "
                + ") t10 "
                + "group by name");
        rows1.show();

        Dataset<Row> rows2 = sparkSession.sql("select name,max_vals(achieve1,achieve2) as max_achieve from user group by name");
        rows2.show();
    }
}
The code above builds a Dataset with the columns id, name, subject, achieve1 and achieve2. MutilMax computes the maximum of each column within a group and then returns the largest of those per-column maxima.
MutilMax.java (the UDAF):
package com.dx.streaming.producer;

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class MutilMax extends UserDefinedAggregateFunction {
    private static final long serialVersionUID = 3924913264741215131L;
    private int columnSize = 1;
    private Double defaultValue;

    public MutilMax(int columnSize, double defaultValue) {
        this.columnSize = columnSize;
        this.defaultValue = defaultValue;
    }

    @Override
    public StructType inputSchema() {
        List<StructField> inputFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            inputFields.add(DataTypes.createStructField("myinput" + i, DataTypes.DoubleType, true));
        }
        StructType inputSchema = DataTypes.createStructType(inputFields);
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        // The buffer keeps one running maximum per input column.
        List<StructField> bufferFields = new ArrayList<StructField>();
        for (int i = 0; i < this.columnSize; i++) {
            bufferFields.add(DataTypes.createStructField("mymax" + i, DataTypes.DoubleType, true));
        }
        StructType bufferSchema = DataTypes.createStructType(bufferFields);
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return false;
    }

    // Sets the initial ("zero value") buffer: every running maximum starts at 0.
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        for (int i = 0; i < this.columnSize; i++) {
            buffer.update(i, 0d);
        }
    }

    // Combine within a partition: for every column keep the larger of the buffered maximum
    // and the input value (similar to combineByKey).
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer.getDouble(i) > input.getDouble(i)) {
                buffer.update(i, buffer.getDouble(i));
            } else {
                buffer.update(i, input.getDouble(i));
            }
        }
    }

    // Merge across partitions: fold buffer2 into buffer1, column by column (similar to reduceByKey).
    // There is no return value; the merged values must be written back into buffer1.
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer1.getDouble(i) > buffer2.getDouble(i)) {
                buffer1.update(i, buffer1.getDouble(i));
            } else {
                buffer1.update(i, buffer2.getDouble(i));
            }
        }
    }

    // Compute the final result: the largest of the per-column maxima, falling back to
    // defaultValue when every buffered maximum is still 0.
    @Override
    public Object evaluate(Row buffer) {
        Double max = Double.MIN_VALUE;
        for (int i = 0; i < this.columnSize; i++) {
            if (buffer.getDouble(i) > max) {
                max = buffer.getDouble(i);
            }
        }
        if (max == Double.MIN_VALUE) {
            max = this.defaultValue;
        }
        return max;
    }
}
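One caveat worth calling out: because initialize seeds every running maximum with 0, MutilMax implicitly assumes non-negative inputs (fine for the score data here, but negative values would lose against the initial 0). A minimal sketch, not part of the original code, of how initialize and evaluate could be adjusted to also handle negative values, using Double.NEGATIVE_INFINITY as the "nothing seen yet" marker:

// Sketch only: an alternative zero value for MutilMax that also works for negative inputs.
// The zero-value contract still holds, since max(-Infinity, -Infinity) is -Infinity.
@Override
public void initialize(MutableAggregationBuffer buffer) {
    for (int i = 0; i < this.columnSize; i++) {
        buffer.update(i, Double.NEGATIVE_INFINITY); // "no value seen yet"
    }
}

@Override
public Object evaluate(Row buffer) {
    double max = Double.NEGATIVE_INFINITY;
    for (int i = 0; i < this.columnSize; i++) {
        max = Math.max(max, buffer.getDouble(i));
    }
    // Fall back to defaultValue only when no rows were aggregated at all.
    return max == Double.NEGATIVE_INFINITY ? this.defaultValue : max;
}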
Output:
+---+--------+---------+--------+--------+
| id| name| subject|achieve1|achieve2|
+---+--------+---------+--------+--------+
| 1|zhangsan| English| 80.0| 89.0|
| 2|zhangsan| History| 87.0| 88.0|
| 3|zhangsan| Chinese| 88.0| 87.0|
| 4|zhangsan|Chemistry| 96.0| 95.0|
| 5| lisi| English| 70.0| 75.0|
| 6| lisi| Chinese| 74.0| 67.0|
| 7| lisi| History| 75.0| 80.0|
| 8| lisi|Chemistry| 77.0| 70.0|
| 9| lisi| Physics| 79.0| 80.0|
| 10| lisi| Biology| 82.0| 83.0|
| 11| wangwu| English| 96.0| 84.0|
| 12| wangwu| Chinese| 98.0| 64.0|
| 13| wangwu| History| 91.0| 92.0|
| 14| zhaoliu| English| 68.0| 80.0|
| 15| zhaoliu| Chinese| 66.0| 69.0|
+---+--------+---------+--------+--------+
+--------+-----------+
| name|max_achieve|
+--------+-----------+
| wangwu| 98.0|
| zhaoliu| 80.0|
|zhangsan| 96.0|
| lisi| 83.0|
+--------+-----------+
+--------+-----------+
| name|max_achieve|
+--------+-----------+
| wangwu| 98.0|
| zhaoliu| 80.0|
|zhangsan| 96.0|
| lisi| 83.0|
+--------+-----------+
Writing an Aggregator in Spark
Implementing an avg function with the typed Aggregator API:
Step 1: define an Average class that stores the running count and sum.
import java.io.Serializable;

public class Average implements Serializable {
    private long sum;
    private long count;

    // Constructors, getters, setters...
    public long getSum() { return sum; }
    public void setSum(long sum) { this.sum = sum; }
    public long getCount() { return count; }
    public void setCount(long count) { this.count = count; }

    public Average() {
    }

    public Average(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }
}
Step 2: define an Employee class holding the employee's name and salary.
import java.io.Serializable;

public class Employee implements Serializable {
    private String name;
    private long salary;

    // Constructors, getters, setters...
    public String getName() { return name; }
    public void setName(String name) { this.name = name; }
    public long getSalary() { return salary; }
    public void setSalary(long salary) { this.salary = salary; }

    public Employee() {
    }

    public Employee(String name, long salary) {
        this.name = name;
        this.salary = salary;
    }
}
Step 3: define the Aggregator that averages employee salaries.
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;

public class MyAverage extends Aggregator<Employee, Average, Double> {
    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    // Combine two values to produce a new value. For performance, the function may modify `buffer`
    // and return it instead of constructing a new object
    @Override
    public Average reduce(Average buffer, Employee employee) {
        long newSum = buffer.getSum() + employee.getSalary();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    // Merge two intermediate values
    @Override
    public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    // Transform the output of the reduction
    @Override
    public Double finish(Average reduction) {
        return ((double) reduction.getSum()) / reduction.getCount();
    }

    // Specifies the Encoder for the intermediate value type
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    // Specifies the Encoder for the final output value type
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
Step 4: call the Aggregator from Spark and verify the result.
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.*;

public class SparkClient {
    public static void main(String[] args) {
        final SparkSession spark = SparkSession.builder().master("local[*]").appName("test_agg").getOrCreate();
        final JavaSparkContext ctx = JavaSparkContext.fromSparkContext(spark.sparkContext());

        List<Employee> employeeList = new ArrayList<Employee>();
        employeeList.add(new Employee("Michael", 3000L));
        employeeList.add(new Employee("Andy", 4500L));
        employeeList.add(new Employee("Justin", 3500L));
        employeeList.add(new Employee("Berta", 4000L));

        JavaRDD<Employee> rows = ctx.parallelize(employeeList);
        Dataset<Employee> ds = spark.createDataFrame(rows, Employee.class).map(new MapFunction<Row, Employee>() {
            @Override
            public Employee call(Row row) throws Exception {
                return new Employee(row.getString(0), row.getLong(1));
            }
        }, Encoders.bean(Employee.class));
        ds.show();
        // +-------+------+
        // |   name|salary|
        // +-------+------+
        // |Michael|  3000|
        // |   Andy|  4500|
        // | Justin|  3500|
        // |  Berta|  4000|
        // +-------+------+

        MyAverage myAverage = new MyAverage();
        // Convert the function to a `TypedColumn` and give it a name
        TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
        Dataset<Double> result = ds.select(averageSalary);
        result.show();
        // +--------------+
        // |average_salary|
        // +--------------+
        // |        3750.0|
        // +--------------+
    }
}
Output:
+-------+------+
|   name|salary|
+-------+------+
|Michael|  3000|
|   Andy|  4500|
| Justin|  3500|
|  Berta|  4000|
+-------+------+

+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+
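One last note, hedged as an assumption about your Spark version: starting with Spark 3.0, UserDefinedAggregateFunction is deprecated and an Aggregator can be registered for SQL use directly via org.apache.spark.sql.functions.udaf. The sketch below is illustrative only (the LongAverage class and the my_avg name are not part of the original code); it reuses the Average buffer class from step 1 but takes a single Long column as input, so it can be invoked on the salary column from SQL:

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;

// A minimal sketch, assuming Spark 3.0+: the same count/sum logic as MyAverage,
// but typed on a single Long column so it can be registered as a SQL aggregate.
public class LongAverage extends Aggregator<Long, Average, Double> {
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    @Override
    public Average reduce(Average buffer, Long salary) {
        buffer.setSum(buffer.getSum() + salary);
        buffer.setCount(buffer.getCount() + 1);
        return buffer;
    }

    @Override
    public Average merge(Average b1, Average b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    @Override
    public Double finish(Average reduction) {
        return ((double) reduction.getSum()) / reduction.getCount();
    }

    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

// Registration and SQL usage (inside SparkClient.main, after ds is built):
//   ds.createOrReplaceTempView("employees");
//   spark.udf().register("my_avg", functions.udaf(new LongAverage(), Encoders.LONG()));
//   spark.sql("select my_avg(salary) from employees").show();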