zoukankan      html  css  js  c++  java
  • Kafka:ZK+Kafka+Spark Streaming集群环境搭建(十五)Spark编写UDF、UDAF、Agg函数

    Spark SQL提供了丰富的内置函数供开发者使用,但实际业务场景可能很复杂,内置函数不一定能满足业务需求,因此Spark SQL还提供了可扩展的用户自定义函数(UDF/UDAF)机制。

    UDF:是普通函数,输入一个或多个参数,返回一个值。比如:len(),isnull()

    UDAF:是聚合函数,输入一组值,返回一个聚合结果。比如:max(),avg(),sum()

    Spark编写UDF函数

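    下边的例子采用Spark 2.0之前的SQLContext写法(SQLContext在2.0中已被标记为deprecated;不过示例中用到的Dataset和createOrReplaceTempView是2.0的API,因此需要在Spark 2.x上运行)。例子中展示的是只有一个输入参数、一个输出值的UDF。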

    package com.dx.streaming.producer;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SQLContext;
    import org.apache.spark.sql.api.java.UDF1;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    /**
     * Demonstrates registering a one-argument UDF through the pre-2.0 {@code SQLContext} API.
     *
     * <p>Builds an ({@code id}, {@code name}) DataFrame from "id,name" strings, registers
     * {@code strLength} and calls it from SQL. The Spark context is always stopped, even if a
     * query fails.
     */
    public class TestUDF1 {
        public static void main(String[] args) {
            SparkConf sparkConf = new SparkConf();
            sparkConf.setMaster("local[2]");
            sparkConf.setAppName("spark udf test");
            JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
            try {
                @SuppressWarnings("deprecation")
                SQLContext sqlContext = new SQLContext(javaSparkContext);
                JavaRDD<String> javaRDD = javaSparkContext.parallelize(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"));
                JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
                    private static final long serialVersionUID = -4769584490875182711L;

                    /** Splits "id,name" into a Row of two strings. */
                    @Override
                    public Row call(String line) throws Exception {
                        String[] parts = line.split(",");
                        // explicit cast: pass the array as the varargs array itself
                        return RowFactory.create((Object[]) parts);
                    }
                });

                List<StructField> fields = new ArrayList<StructField>();
                fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
                fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));

                StructType schema = DataTypes.createStructType(fields);
                Dataset<Row> ds = sqlContext.createDataFrame(rowRDD, schema);
                ds.createOrReplaceTempView("user");

                // Choose UDF1, UDF2, ... UDF22 according to the number of arguments the function takes.
                sqlContext.udf().register("strLength", new UDF1<String, Integer>() {
                    private static final long serialVersionUID = -8172995965965931129L;

                    /** Returns the length of t1, or null when t1 is SQL NULL. */
                    @Override
                    public Integer call(String t1) throws Exception {
                        // "name" is declared nullable; SQL NULL arrives as Java null
                        if (t1 == null) {
                            return null;
                        }
                        return t1.length();
                    }
                }, DataTypes.IntegerType);

                Dataset<Row> rows = sqlContext.sql("select id,name,strLength(name) as length from user");
                rows.show();
            } finally {
                // release the local Spark context even when a query throws
                javaSparkContext.stop();
            }
        }
    }

    输出效果:

    +---+--------+------+
    | id|    name|length|
    +---+--------+------+
    |  1|zhangsan|     8|
    |  2|    lisi|     4|
    |  3|  wangwu|     6|
    |  4| zhaoliu|     7|
    +---+--------+------+

    上边的UDF展示的是单个输入、单个输出的函数。下边将使用Spark 2.0的SparkSession实现一个三个输入、一个输出的UDF函数。

    package com.dx.streaming.producer;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.api.java.UDF1;
    import org.apache.spark.sql.api.java.UDF3;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    /**
     * Demonstrates Spark 2.0-style UDF registration on a {@link SparkSession}.
     *
     * <p>Registers a one-argument {@code strLength} and a three-argument {@code strConcat}. It then
     * queries them through a temp view built either from the {@code Person} bean or from an explicit
     * schema. Both UDFs return null for SQL NULL input. The session is always stopped.
     */
    public class TestUDF2 {
        public static void main(String[] args) {
            SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
            try {
                Dataset<String> row = sparkSession.createDataset(Arrays.asList("1,zhangsan", "2,lisi", "3,wangwu", "4,zhaoliu"), Encoders.STRING());

                // Choose UDF1, UDF2, ... UDF22 according to the number of arguments the function takes.
                // A UDF only needs to be registered once per session; every later query can use it.
                sparkSession.udf().register("strLength", new UDF1<String, Integer>() {
                    private static final long serialVersionUID = -8172995965965931129L;

                    /** Returns the length of t1, or null when t1 is SQL NULL. */
                    @Override
                    public Integer call(String t1) throws Exception {
                        if (t1 == null) {
                            return null;
                        }
                        return t1.length();
                    }
                }, DataTypes.IntegerType);
                sparkSession.udf().register("strConcat", new UDF3<String, String, String, String>() {
                    private static final long serialVersionUID = -8172995965965931129L;

                    /** Returns t1 + combChar + t2, or null if any argument is null (like SQL concat). */
                    @Override
                    public String call(String combChar, String t1, String t2) throws Exception {
                        // without this check Java string concatenation would emit the text "null"
                        if (combChar == null || t1 == null || t2 == null) {
                            return null;
                        }
                        return t1 + combChar + t2;
                    }
                }, DataTypes.StringType);

                showByStruct(sparkSession, row);
                System.out.println("==========================================");
                showBySchema(sparkSession, row);
            } finally {
                // release local Spark resources even when a query throws
                sparkSession.stop();
            }
        }

        /** Builds the "user" view from an explicit two-string schema and runs the UDF query. */
        private static void showBySchema(SparkSession sparkSession, Dataset<String> row) {
            JavaRDD<String> javaRDD = row.javaRDD();
            JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
                private static final long serialVersionUID = -4769584490875182711L;

                @Override
                public Row call(String line) throws Exception {
                    String[] parts = line.split(",");
                    // explicit cast: pass the array as the varargs array itself
                    return RowFactory.create((Object[]) parts);
                }
            });

            List<StructField> fields = new ArrayList<StructField>();
            fields.add(DataTypes.createStructField("id", DataTypes.StringType, true));
            fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));

            StructType schema = DataTypes.createStructType(fields);
            Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
            ds.show();
            ds.createOrReplaceTempView("user");

            Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('+',id,name) as str from user");
            rows.show();
        }

        /** Builds the "user" view from Person beans and runs the UDF query. */
        private static void showByStruct(SparkSession sparkSession, Dataset<String> row) {
            JavaRDD<Person> map = row.javaRDD().map(Person::parsePerson);
            Dataset<Row> persons = sparkSession.createDataFrame(map, Person.class);
            persons.show();

            persons.createOrReplaceTempView("user");

            Dataset<Row> rows = sparkSession.sql("select id,name,strLength(name) as length,strConcat('-',id,name) as str from user");
            rows.show();
        }
    }

    Person.java

    package com.dx.streaming.producer;
    
    import java.io.Serializable;
    
    /**
     * Mutable JavaBean for one "id,name" record, used as the bean class in
     * {@code createDataFrame(rdd, Person.class)}.
     */
    public class Person implements Serializable {
        private static final long serialVersionUID = 1L;

        private String id;
        private String name;

        /** No-arg constructor, as the JavaBean convention expects. */
        public Person() {
        }

        public Person(String id, String name) {
            this.id = id;
            this.name = name;
        }

        public String getId() {
            return id;
        }

        public void setId(String id) {
            this.id = id;
        }

        public String getName() {
            return name;
        }

        public void setName(String name) {
            this.name = name;
        }

        /**
         * Parses a line of the form {@code "id,name"}. Fields after the second are ignored.
         *
         * @param line the comma-separated record
         * @return a new Person built from the first two fields
         * @throws NullPointerException if {@code line} is null
         * @throws IllegalArgumentException if the line has fewer than two fields
         */
        public static Person parsePerson(String line) {
            String[] fields = line.split(",");
            if (fields.length < 2) {
                // fail with a readable message instead of ArrayIndexOutOfBoundsException
                throw new IllegalArgumentException("expected \"id,name\" but got: " + line);
            }
            return new Person(fields[0], fields[1]);
        }
    }
    View Code

    需要注意的是:UDF函数在SparkSession中只需要注册一次,之后就可以在多个查询中多次调用。

    输出效果:

    +---+--------+
    | id|    name|
    +---+--------+
    |  1|zhangsan|
    |  2|    lisi|
    |  3|  wangwu|
    |  4| zhaoliu|
    +---+--------+
    
    +---+--------+------+----------+
    | id|    name|length|       str|
    +---+--------+------+----------+
    |  1|zhangsan|     8|1-zhangsan|
    |  2|    lisi|     4|    2-lisi|
    |  3|  wangwu|     6|  3-wangwu|
    |  4| zhaoliu|     7| 4-zhaoliu|
    +---+--------+------+----------+
    
    ==========================================
    
    +---+--------+
    | id|    name|
    +---+--------+
    |  1|zhangsan|
    |  2|    lisi|
    |  3|  wangwu|
    |  4| zhaoliu|
    +---+--------+
    
    +---+--------+------+----------+
    | id|    name|length|       str|
    +---+--------+------+----------+
    |  1|zhangsan|     8|1+zhangsan|
    |  2|    lisi|     4|    2+lisi|
    |  3|  wangwu|     6|  3+wangwu|
    |  4| zhaoliu|     7| 4+zhaoliu|
    +---+--------+------+----------+

    相信认真阅读的话,通过上边的两个示例,就可以掌握其用法。

    Spark编写UDAF函数

    自定义聚合函数需要继承抽象类UserDefinedAggregateFunction,以下是该抽象类的定义(摘自较早版本的Spark源码;从Spark 3.0起该类已被标记为deprecated,推荐改用Aggregator并通过functions.udaf注册):

    package org.apache.spark.sql.expressions
     
    import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2}
    import org.apache.spark.sql.execution.aggregate.ScalaUDAF
    import org.apache.spark.sql.{Column, Row}
    import org.apache.spark.sql.types._
    import org.apache.spark.annotation.Experimental
     
    // NOTE(review): quoted Spark source from an early release (it references AggregateExpression2);
    // newer Spark versions differ and deprecate this class — TODO confirm against the version in use.
    /**
     * :: Experimental ::
     * The base class for implementing user-defined aggregate functions (UDAF).
     */
    @Experimental
    abstract class UserDefinedAggregateFunction extends Serializable {
     
      /**
       * A [[StructType]] represents data types of input arguments of this aggregate function.
       * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments
       * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like
       *
       * ```
       *   new StructType()
       *    .add("doubleInput", DoubleType)
       *    .add("longInput", LongType)
       * ```
       *
       * The name of a field of this [[StructType]] is only used to identify the corresponding
       * input argument. Users can choose names to identify the input arguments.
       */
       // Data types of the input arguments.
      def inputSchema: StructType
     
      /**
       * A [[StructType]] represents data types of values in the aggregation buffer.
       * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values
       * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]],
       * the returned [[StructType]] will look like
       *
       * ```
       *   new StructType()
       *    .add("doubleInput", DoubleType)
       *    .add("longInput", LongType)
       * ```
       *
       * The name of a field of this [[StructType]] is only used to identify the corresponding
       * buffer value. Users can choose names to identify the input arguments.
       */
       // Data types of the intermediate values held in the aggregation buffer.
      def bufferSchema: StructType
     
      /**
       * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]].
       */
       // Data type of the final aggregation result.
      def dataType: DataType
     
      /**
       * Returns true if this function is deterministic, i.e. given the same input,
       * always return the same output.
       */
       // Determinism flag: if true, the same input always produces the same result.
      def deterministic: Boolean
     
      /**
       * Initializes the given aggregation buffer, i.e. the zero value of the aggregation buffer.
       *
       * The contract should be that applying the merge function on two initial buffers should just
       * return the initial buffer itself, i.e.
       * `merge(initialBuffer, initialBuffer)` should equal `initialBuffer`.
       */
       // Sets the initial buffer value. It must satisfy: merging two initial buffers with merge() gives the initial buffer back.
       // E.g. an initial value of 1 with an additive merge would give 2 after merging two initial buffers — that is wrong, hence the name "zero value".
      def initialize(buffer: MutableAggregationBuffer): Unit
      /**
       * Updates the given aggregation buffer `buffer` with new input data from `input`.
       *
       * This is called once per input row.
       */
       // Folds one input row into the buffer (comparable to the per-partition combine step of combineByKey).
      def update(buffer: MutableAggregationBuffer, input: Row): Unit
      /**
       * Merges two aggregation buffers and stores the updated buffer values back to `buffer1`.
       *
       * This is called when we merge two partially aggregated data together.
       */
       // Merges buffer2 into buffer1; used when combining partial results from different partitions (comparable to reduceByKey).
       // Note there is no return value: the implementation must write the merged values into buffer1 itself.
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
      /**
       * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given
       * aggregation buffer.
       */
       // Computes and returns the final aggregation result.
      def evaluate(buffer: Row): Any
      /**
       * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
       */
       // Aggregates over all input values (isDistinct = false).
      @scala.annotation.varargs
      def apply(exprs: Column*): Column = {
        val aggregateExpression =
          AggregateExpression2(
            ScalaUDAF(exprs.map(_.expr), this),
            Complete,
            isDistinct = false)
        Column(aggregateExpression)
      }
     
      /**
       * Creates a [[Column]] for this UDAF using the distinct values of the given
       * [[Column]]s as input arguments.
       */
       // Aggregates over the de-duplicated input values (isDistinct = true).
      @scala.annotation.varargs
      def distinct(exprs: Column*): Column = {
        val aggregateExpression =
          AggregateExpression2(
            ScalaUDAF(exprs.map(_.expr), this),
            Complete,
            isDistinct = true)
        Column(aggregateExpression)
      }
    }
     
    /**
     * :: Experimental ::
     * A [[Row]] representing an mutable aggregation buffer.
     *
     * This is not meant to be extended outside of Spark.
     */
    // Passed to initialize/update/merge of UserDefinedAggregateFunction; implementations write
    // their intermediate state through update(i, value) instead of returning a new buffer.
    @Experimental
    abstract class MutableAggregationBuffer extends Row {
     
      /** Update the ith value of this buffer. */
      def update(i: Int, value: Any): Unit
    }

    实现单列求平均数的聚合函数:

    package com.dx.streaming.producer;
    
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.expressions.MutableAggregationBuffer;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;
    
    /**
     * Single-column average UDAF whose result is rounded to two decimal places.
     *
     * <p>Buffer layout: slot 0 = number of non-null rows ({@code mycnt}, long), slot 1 = running
     * sum ({@code mysum}, double). Null inputs are skipped, and an empty (or all-null) group
     * yields {@code null}, mirroring the built-in {@code avg}.
     */
    public class SimpleAvg extends UserDefinedAggregateFunction {
        private static final long serialVersionUID = 3924913264741215131L;

        /** Buffer slot holding the count of aggregated rows. */
        private static final int COUNT = 0;
        /** Buffer slot holding the sum of aggregated values. */
        private static final int SUM = 1;

        /** One double input column. */
        @Override
        public StructType inputSchema() {
            return new StructType().add("myinput", DataTypes.DoubleType);
        }

        /** (count: long, sum: double) — order must match COUNT/SUM. */
        @Override
        public StructType bufferSchema() {
            return new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        }

        @Override
        public DataType dataType() {
            return DataTypes.DoubleType;
        }

        @Override
        public boolean deterministic() {
            return true;
        }

        /**
         * Sets the buffer to its zero value. merge(zero, zero) must equal zero, so both slots
         * start at 0; any non-zero start would be double-counted when buffers are merged.
         */
        @Override
        public void initialize(MutableAggregationBuffer buffer) {
            buffer.update(COUNT, 0L); // long 0 for the row count
            buffer.update(SUM, 0d);   // double 0 for the running sum
        }

        /** Folds one input row into the buffer (within a partition). */
        @Override
        public void update(MutableAggregationBuffer buffer, Row input) {
            if (input.isNullAt(0)) {
                // Row.getDouble throws on null; skip it like the built-in avg does
                return;
            }
            buffer.update(COUNT, buffer.getLong(COUNT) + 1);
            buffer.update(SUM, buffer.getDouble(SUM) + input.getDouble(0));
        }

        /** Adds buffer2's count and sum into buffer1 (across partitions); returns nothing by contract. */
        @Override
        public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
            buffer1.update(COUNT, buffer1.getLong(COUNT) + buffer2.getLong(COUNT));
            buffer1.update(SUM, buffer1.getDouble(SUM) + buffer2.getDouble(SUM));
        }

        /**
         * Returns sum / count rounded to two decimals, or null when no rows were aggregated.
         */
        @Override
        public Object evaluate(Row buffer) {
            long count = buffer.getLong(COUNT);
            if (count == 0) {
                return null; // avoid NaN from 0.0 / 0
            }
            double avg = buffer.getDouble(SUM) / count;
            // Locale.ROOT guarantees a '.' decimal separator so parseDouble cannot fail
            // (e.g. the default German locale would format "76,17").
            return Double.parseDouble(String.format(java.util.Locale.ROOT, "%.2f", avg));
        }
    }

    下边展示下如何使用自定义的UDAF函数:

    package com.dx.streaming.producer;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    /**
     * Compares the built-in {@code avg} with the {@code SimpleAvg} UDAF (registered as
     * {@code avg_format}) on a per-student score table. The session is always stopped.
     */
    public class TestUDAF1 {

        public static void main(String[] args) {
            SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
            try {
                Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                        "1,zhangsan,English,80",
                        "2,zhangsan,History,87",
                        "3,zhangsan,Chinese,88",
                        "4,zhangsan,Chemistry,96",
                        "5,lisi,English,70",
                        "6,lisi,Chinese,74",
                        "7,lisi,History,75",
                        "8,lisi,Chemistry,77",
                        "9,lisi,Physics,79",
                        "10,lisi,Biology,82",
                        "11,wangwu,English,96",
                        "12,wangwu,Chinese,98",
                        "13,wangwu,History,91",
                        "14,zhaoliu,English,68",
                        "15,zhaoliu,Chinese,66"), Encoders.STRING());
                JavaRDD<String> javaRDD = row.javaRDD();
                JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
                    private static final long serialVersionUID = -4769584490875182711L;

                    /** Parses "id,name,subject,score" into a Row matching the schema below. */
                    @Override
                    public Row call(String line) throws Exception {
                        String[] parts = line.split(",");
                        Integer id = Integer.parseInt(parts[0]);
                        String name = parts[1];
                        String subject = parts[2];
                        Double achieve = Double.parseDouble(parts[3]);
                        return RowFactory.create(id, name, subject, achieve);
                    }
                });

                List<StructField> fields = new ArrayList<StructField>();
                fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
                fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
                fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
                fields.add(DataTypes.createStructField("achieve", DataTypes.DoubleType, false));

                StructType schema = DataTypes.createStructType(fields);
                Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
                ds.show();
                ds.createOrReplaceTempView("user");

                UserDefinedAggregateFunction udaf = new SimpleAvg();
                sparkSession.udf().register("avg_format", udaf);

                Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve) avg_achieve from user group by name");
                rows1.show();

                Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve) avg_achieve from user group by name");
                rows2.show();
            } finally {
                // previously the session was never stopped; release local Spark resources
                sparkSession.stop();
            }
        }

    }

    输出结果:

    +---+--------+---------+-------+
    | id|    name|  subject|achieve|
    +---+--------+---------+-------+
    |  1|zhangsan|  English|   80.0|
    |  2|zhangsan|  History|   87.0|
    |  3|zhangsan|  Chinese|   88.0|
    |  4|zhangsan|Chemistry|   96.0|
    |  5|    lisi|  English|   70.0|
    |  6|    lisi|  Chinese|   74.0|
    |  7|    lisi|  History|   75.0|
    |  8|    lisi|Chemistry|   77.0|
    |  9|    lisi|  Physics|   79.0|
    | 10|    lisi|  Biology|   82.0|
    | 11|  wangwu|  English|   96.0|
    | 12|  wangwu|  Chinese|   98.0|
    | 13|  wangwu|  History|   91.0|
    | 14| zhaoliu|  English|   68.0|
    | 15| zhaoliu|  Chinese|   66.0|
    +---+--------+---------+-------+
    
    +--------+-----------------+
    |    name|      avg_achieve|
    +--------+-----------------+
    |  wangwu|             95.0|
    | zhaoliu|             67.0|
    |zhangsan|            87.75|
    |    lisi|76.16666666666667|
    +--------+-----------------+
    
    +--------+-----------+
    |    name|avg_achieve|
    +--------+-----------+
    |  wangwu|       95.0|
    | zhaoliu|       67.0|
    |zhangsan|      87.75|
    |    lisi|      76.17|
    +--------+-----------+

    实现先对多列求和、再求平均数的UDAF聚合函数(注意:下面的测试类与上一个示例同名为TestUDAF1,若放在同一个包中需要改名,否则无法编译):

    package com.dx.streaming.producer;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    /**
     * Compares built-in {@code avg(achieve1+achieve2)} with the {@code MutilAvg(2)} UDAF
     * (registered as {@code avg_format}). The session is always stopped.
     *
     * <p>NOTE: this shares the class name TestUDAF1 and package with the single-column example;
     * rename one of them if both are compiled together.
     */
    public class TestUDAF1 {

        public static void main(String[] args) {
            SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
            try {
                Dataset<String> row = sparkSession.createDataset(Arrays.asList(
                        "1,zhangsan,English,80,89",
                        "2,zhangsan,History,87,88",
                        "3,zhangsan,Chinese,88,87",
                        "4,zhangsan,Chemistry,96,95",
                        "5,lisi,English,70,75",
                        "6,lisi,Chinese,74,67",
                        "7,lisi,History,75,80",
                        "8,lisi,Chemistry,77,70",
                        "9,lisi,Physics,79,80",
                        "10,lisi,Biology,82,83",
                        "11,wangwu,English,96,84",
                        "12,wangwu,Chinese,98,64",
                        "13,wangwu,History,91,92",
                        "14,zhaoliu,English,68,80",
                        "15,zhaoliu,Chinese,66,69"), Encoders.STRING());
                JavaRDD<String> javaRDD = row.javaRDD();
                JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
                    private static final long serialVersionUID = -4769584490875182711L;

                    /** Parses "id,name,subject,score1,score2" into a Row matching the schema below. */
                    @Override
                    public Row call(String line) throws Exception {
                        String[] parts = line.split(",");
                        Integer id = Integer.parseInt(parts[0]);
                        String name = parts[1];
                        String subject = parts[2];
                        Double achieve1 = Double.parseDouble(parts[3]);
                        Double achieve2 = Double.parseDouble(parts[4]);
                        return RowFactory.create(id, name, subject, achieve1, achieve2);
                    }
                });

                List<StructField> fields = new ArrayList<StructField>();
                fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
                fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
                fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
                fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
                fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));

                StructType schema = DataTypes.createStructType(fields);
                Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
                ds.show();
                ds.createOrReplaceTempView("user");

                UserDefinedAggregateFunction udaf = new MutilAvg(2);
                sparkSession.udf().register("avg_format", udaf);

                Dataset<Row> rows1 = sparkSession.sql("select name,avg(achieve1+achieve2) avg_achieve from user group by name");
                rows1.show();

                Dataset<Row> rows2 = sparkSession.sql("select name,avg_format(achieve1,achieve2) avg_achieve from user group by name");
                rows2.show();
            } finally {
                // previously the session was never stopped; release local Spark resources
                sparkSession.stop();
            }
        }
    }

    上边创建了一个Dataset,包含列:id,name,subject,achieve1,achieve2;其中MutilAvg实现的是先把同一行的多列求和,再对各行的和求平均值。

    MutilAvg.java(udaf函数):

    package com.dx.streaming.producer;
    
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.expressions.MutableAggregationBuffer;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;
    
    /**
     * UDAF that sums {@code columnSize} double columns per row and averages those row sums,
     * rounding the result to two decimal places.
     *
     * <p>Buffer layout: slot 0 = number of aggregated rows ({@code mycnt}, long), slot 1 = running
     * sum of row sums ({@code mysum}, double). A row with any null column is skipped (as
     * {@code avg(a+b)} does, since {@code a+b} is null), and an empty group yields {@code null}.
     */
    public class MutilAvg extends UserDefinedAggregateFunction {
        private static final long serialVersionUID = 3924913264741215131L;

        /** Buffer slot holding the count of aggregated rows. */
        private static final int COUNT = 0;
        /** Buffer slot holding the sum of aggregated row sums. */
        private static final int SUM = 1;

        /** Number of double input columns summed per row. */
        private final int columnSize;

        /**
         * @param columnSize number of input columns; must be at least 1
         * @throws IllegalArgumentException if {@code columnSize < 1}
         */
        public MutilAvg(int columnSize) {
            if (columnSize < 1) {
                throw new IllegalArgumentException("columnSize must be >= 1, got " + columnSize);
            }
            this.columnSize = columnSize;
        }

        /** {@code columnSize} double columns named myinput0, myinput1, ... */
        @Override
        public StructType inputSchema() {
            StructType structType = new StructType();
            for (int i = 0; i < columnSize; i++) {
                // StructType is immutable: add() returns a NEW instance, so it must be reassigned
                // (the original discarded it and returned an empty schema).
                structType = structType.add("myinput" + i, DataTypes.DoubleType);
            }
            return structType;
        }

        /** (count: long, sum: double) — order must match COUNT/SUM. */
        @Override
        public StructType bufferSchema() {
            return new StructType().add("mycnt", DataTypes.LongType).add("mysum", DataTypes.DoubleType);
        }

        @Override
        public DataType dataType() {
            return DataTypes.DoubleType;
        }

        @Override
        public boolean deterministic() {
            return true;
        }

        /**
         * Sets the buffer to its zero value. merge(zero, zero) must equal zero, so both slots
         * start at 0; any non-zero start would be double-counted when buffers are merged.
         */
        @Override
        public void initialize(MutableAggregationBuffer buffer) {
            buffer.update(COUNT, 0L); // long 0 for the row count
            buffer.update(SUM, 0d);   // double 0 for the running sum
        }

        /** Adds the sum of this row's columns to the buffer (within a partition). */
        @Override
        public void update(MutableAggregationBuffer buffer, Row input) {
            // One input row carries several columns, so first add up the columns of this row.
            double rowSum = 0d;
            for (int i = 0; i < columnSize; i++) {
                if (input.isNullAt(i)) {
                    // Row.getDouble throws on null; drop the row like avg(a+b) would
                    return;
                }
                rowSum += input.getDouble(i);
            }
            buffer.update(COUNT, buffer.getLong(COUNT) + 1);
            buffer.update(SUM, buffer.getDouble(SUM) + rowSum);
        }

        /** Adds buffer2's count and sum into buffer1 (across partitions); returns nothing by contract. */
        @Override
        public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
            buffer1.update(COUNT, buffer1.getLong(COUNT) + buffer2.getLong(COUNT));
            buffer1.update(SUM, buffer1.getDouble(SUM) + buffer2.getDouble(SUM));
        }

        /**
         * Returns sum / count rounded to two decimals, or null when no rows were aggregated.
         */
        @Override
        public Object evaluate(Row buffer) {
            long count = buffer.getLong(COUNT);
            if (count == 0) {
                return null; // avoid NaN from 0.0 / 0
            }
            double avg = buffer.getDouble(SUM) / count;
            // Locale.ROOT guarantees a '.' decimal separator so parseDouble cannot fail.
            return Double.parseDouble(String.format(java.util.Locale.ROOT, "%.2f", avg));
        }
    }
    View Code

    测试输出:

            +---+--------+---------+--------+--------+
            | id|    name|  subject|achieve1|achieve2|
            +---+--------+---------+--------+--------+
            |  1|zhangsan|  English|    80.0|    89.0|
            |  2|zhangsan|  History|    87.0|    88.0|
            |  3|zhangsan|  Chinese|    88.0|    87.0|
            |  4|zhangsan|Chemistry|    96.0|    95.0|
            |  5|    lisi|  English|    70.0|    75.0|
            |  6|    lisi|  Chinese|    74.0|    67.0|
            |  7|    lisi|  History|    75.0|    80.0|
            |  8|    lisi|Chemistry|    77.0|    70.0|
            |  9|    lisi|  Physics|    79.0|    80.0|
            | 10|    lisi|  Biology|    82.0|    83.0|
            | 11|  wangwu|  English|    96.0|    84.0|
            | 12|  wangwu|  Chinese|    98.0|    64.0|
            | 13|  wangwu|  History|    91.0|    92.0|
            | 14| zhaoliu|  English|    68.0|    80.0|
            | 15| zhaoliu|  Chinese|    66.0|    69.0|
            +---+--------+---------+--------+--------+
    
            +--------+-----------+
            |    name|avg_achieve|
            +--------+-----------+
            |  wangwu|      175.0|
            | zhaoliu|      141.5|
            |zhangsan|      177.5|
            |    lisi|      152.0|
            +--------+-----------+
    
            +--------+-----------+
            |    name|avg_achieve|
            +--------+-----------+
            |  wangwu|      175.0|
            | zhaoliu|      141.5|
            |zhangsan|      177.5|
            |    lisi|      152.0|
            +--------+-----------+

    实现先对多列分别求最大值,再从这些最大值中找出最大值的UDAF聚合函数:

    package com.dx.streaming.producer;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    public class TestUDAF2 {
    
        /**
         * Demo entry point: builds a small score table, registers {@code MutilMax} as the SQL
         * function {@code max_vals}, and compares its result with an equivalent
         * "max per column, then union all, then max" SQL query.
         *
         * @param args unused
         */
        public static void main(String[] args) {
            SparkSession sparkSession = SparkSession.builder().appName("spark udf test").master("local[2]").getOrCreate();
            try {
                // Each line is: id,name,subject,achieve1,achieve2
                Dataset<String> lines = sparkSession.createDataset(Arrays.asList(
                        "1,zhangsan,English,80,89",
                        "2,zhangsan,History,87,88",
                        "3,zhangsan,Chinese,88,87",
                        "4,zhangsan,Chemistry,96,95",
                        "5,lisi,English,70,75",
                        "6,lisi,Chinese,74,67",
                        "7,lisi,History,75,80",
                        "8,lisi,Chemistry,77,70",
                        "9,lisi,Physics,79,80",
                        "10,lisi,Biology,82,83",
                        "11,wangwu,English,96,84",
                        "12,wangwu,Chinese,98,64",
                        "13,wangwu,History,91,92",
                        "14,zhaoliu,English,68,80",
                        "15,zhaoliu,Chinese,66,69"), Encoders.STRING());
                JavaRDD<String> javaRDD = lines.javaRDD();
                JavaRDD<Row> rowRDD = javaRDD.map(new Function<String, Row>() {
                    private static final long serialVersionUID = -4769584490875182711L;
    
                    @Override
                    public Row call(String line) throws Exception {
                        String[] fields = line.split(",");
                        Integer id = Integer.parseInt(fields[0]);
                        String name = fields[1];
                        String subject = fields[2];
                        Double achieve1 = Double.parseDouble(fields[3]);
                        Double achieve2 = Double.parseDouble(fields[4]);
                        return RowFactory.create(id, name, subject, achieve1, achieve2);
                    }
                });
    
                // Schema must match the order of values passed to RowFactory.create above.
                List<StructField> fields = new ArrayList<StructField>();
                fields.add(DataTypes.createStructField("id", DataTypes.IntegerType, true));
                fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
                fields.add(DataTypes.createStructField("subject", DataTypes.StringType, true));
                fields.add(DataTypes.createStructField("achieve1", DataTypes.DoubleType, false));
                fields.add(DataTypes.createStructField("achieve2", DataTypes.DoubleType, false));
    
                StructType schema = DataTypes.createStructType(fields);
                Dataset<Row> ds = sparkSession.createDataFrame(rowRDD, schema);
                ds.show();
    
                ds.createOrReplaceTempView("user");
    
                // Two input columns; 0 is returned for a group in which no value was seen.
                UserDefinedAggregateFunction udaf = new MutilMax(2, 0);
                sparkSession.udf().register("max_vals", udaf);
    
                // Reference result computed with built-in functions only.
                Dataset<Row> rows1 = sparkSession.sql(""
                        + "select name,max(achieve) as max_achieve "
                        + "from "
                        + "("
                        + "select name,max(achieve1) achieve from user group by name "
                        + "union all "
                        + "select name,max(achieve2) achieve from user group by name "
                        + ") t10 "
                        + "group by name");
                rows1.show();
    
                // Same result via the multi-column UDAF.
                Dataset<Row> rows2 = sparkSession.sql("select name,max_vals(achieve1,achieve2) as max_achieve from user group by name");
                rows2.show();
            } finally {
                // Release the local Spark context even if a query fails.
                sparkSession.stop();
            }
        }
    }

    上面创建了一个Dataset,包含列:id、name、subject、achieve1、achieve2。其中MutilMax实现的功能是:先分别求出各列的最大值,再从这些列的最大值中找出最大的一个作为返回值。

    MutilMax.java(udaf函数):

    package com.dx.streaming.producer;
    
    import java.util.ArrayList;
    import java.util.List;
    
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.expressions.MutableAggregationBuffer;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    /**
     * UDAF that computes the maximum of each of {@code columnSize} DOUBLE input columns and then
     * returns the largest of those per-column maxima.
     *
     * <p>SQL NULL inputs are ignored, like the built-in {@code max}. If a group contains no
     * non-null value in any column, {@code defaultValue} is returned.
     */
    public class MutilMax extends UserDefinedAggregateFunction {
        private static final long serialVersionUID = 3924913264741215131L;
        private final int columnSize;
        private final Double defaultValue;
    
        /**
         * @param columnSize   number of DOUBLE input columns; must be at least 1
         * @param defaultValue result returned when no non-null input value was seen
         * @throws IllegalArgumentException if {@code columnSize < 1}
         */
        public MutilMax(int columnSize, double defaultValue) {
            if (columnSize < 1) {
                throw new IllegalArgumentException("columnSize must be >= 1, got " + columnSize);
            }
            this.columnSize = columnSize;
            this.defaultValue = defaultValue;
        }
    
        /** One nullable DOUBLE input field per aggregated column. */
        @Override
        public StructType inputSchema() {
            List<StructField> inputFields = new ArrayList<StructField>();
            for (int i = 0; i < this.columnSize; i++) {
                inputFields.add(DataTypes.createStructField("myinput" + i, DataTypes.DoubleType, true));
            }
            return DataTypes.createStructType(inputFields);
        }
    
        /** One running maximum per column; NULL means "no value seen yet". */
        @Override
        public StructType bufferSchema() {
            List<StructField> bufferFields = new ArrayList<StructField>();
            for (int i = 0; i < this.columnSize; i++) {
                bufferFields.add(DataTypes.createStructField("mymax" + i, DataTypes.DoubleType, true));
            }
            return DataTypes.createStructType(bufferFields);
        }
    
        @Override
        public DataType dataType() {
            return DataTypes.DoubleType;
        }
    
        /** Max is order-independent, so the same input always yields the same result. */
        @Override
        public boolean deterministic() {
            return true;
        }
    
        /**
         * Sets every running maximum to NULL. The initial buffer must be a true "zero value":
         * merging two initial buffers must again give an initial buffer. A numeric seed such as 0
         * would break that for all-negative data, whereas NULL merged with NULL stays NULL.
         */
        @Override
        public void initialize(MutableAggregationBuffer buffer) {
            for (int i = 0; i < this.columnSize; i++) {
                buffer.update(i, null);
            }
        }
    
        /** Folds one input row into the buffer (within a partition); NULL inputs are skipped. */
        @Override
        public void update(MutableAggregationBuffer buffer, Row input) {
            for (int i = 0; i < this.columnSize; i++) {
                if (input.isNullAt(i)) {
                    continue;
                }
                double value = input.getDouble(i);
                if (buffer.isNullAt(i) || value > buffer.getDouble(i)) {
                    buffer.update(i, value);
                }
            }
        }
    
        /**
         * Merges partial buffer {@code buffer2} into {@code buffer1} (across partitions).
         * {@code buffer1} is updated in place; nothing is returned.
         */
        @Override
        public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
            for (int i = 0; i < this.columnSize; i++) {
                if (buffer2.isNullAt(i)) {
                    continue;
                }
                double other = buffer2.getDouble(i);
                if (buffer1.isNullAt(i) || other > buffer1.getDouble(i)) {
                    buffer1.update(i, other);
                }
            }
        }
    
        /** Returns the largest per-column maximum, or {@code defaultValue} if none was seen. */
        @Override
        public Object evaluate(Row buffer) {
            Double max = null;
            for (int i = 0; i < this.columnSize; i++) {
                if (buffer.isNullAt(i)) {
                    continue;
                }
                double candidate = buffer.getDouble(i);
                if (max == null || candidate > max) {
                    max = candidate;
                }
            }
            return max == null ? this.defaultValue : max;
        }
    
    }
    View Code

    打印结果:

            +---+--------+---------+--------+--------+
            | id|    name|  subject|achieve1|achieve2|
            +---+--------+---------+--------+--------+
            |  1|zhangsan|  English|    80.0|    89.0|
            |  2|zhangsan|  History|    87.0|    88.0|
            |  3|zhangsan|  Chinese|    88.0|    87.0|
            |  4|zhangsan|Chemistry|    96.0|    95.0|
            |  5|    lisi|  English|    70.0|    75.0|
            |  6|    lisi|  Chinese|    74.0|    67.0|
            |  7|    lisi|  History|    75.0|    80.0|
            |  8|    lisi|Chemistry|    77.0|    70.0|
            |  9|    lisi|  Physics|    79.0|    80.0|
            | 10|    lisi|  Biology|    82.0|    83.0|
            | 11|  wangwu|  English|    96.0|    84.0|
            | 12|  wangwu|  Chinese|    98.0|    64.0|
            | 13|  wangwu|  History|    91.0|    92.0|
            | 14| zhaoliu|  English|    68.0|    80.0|
            | 15| zhaoliu|  Chinese|    66.0|    69.0|
            +---+--------+---------+--------+--------+
    
            +--------+-----------+
            |    name|max_achieve|
            +--------+-----------+
            |  wangwu|       98.0|
            | zhaoliu|       80.0|
            |zhangsan|       96.0|
            |    lisi|       83.0|
            +--------+-----------+
    
            +--------+-----------+
            |    name|max_achieve|
            +--------+-----------+
            |  wangwu|       98.0|
            | zhaoliu|       80.0|
            |zhangsan|       96.0|
            |    lisi|       83.0|
            +--------+-----------+

    Spark编写Agg函数

    实现一个avg函数:

    第一步:定义一个Average类,用来存储聚合的中间结果(sum和count);

    import java.io.Serializable;
    
    public class Average implements Serializable {
        private long sum;
        private long count;
    
        // Constructors, getters, setters...
        public long getSum() {
            return sum;
        }
    
        public void setSum(long sum) {
            this.sum = sum;
        }
    
        public long getCount() {
            return count;
        }
    
        public void setCount(long count) {
            this.count = count;
        }
    
        public Average() {
    
        }
    
        public Average(long sum, long count) {
            this.sum = sum;
            this.count = count;
        }
    }
    View Code

    第二步:定义一个Employee类,存储员工信息:员工名称、员工薪资;

    import java.io.Serializable;
    
    /**
     * Simple employee record (name and salary) written as a JavaBean: public no-arg
     * constructor plus getters/setters, so it can be used with a bean encoder.
     */
    public class Employee implements Serializable {
        private String name;
        private long salary;
    
        /** Creates an employee with no name and a salary of 0. */
        public Employee() {
            this(null, 0L);
        }
    
        /**
         * @param name   employee name
         * @param salary employee salary
         */
        public Employee(String name, long salary) {
            this.name = name;
            this.salary = salary;
        }
    
        public String getName() {
            return this.name;
        }
    
        public void setName(String name) {
            this.name = name;
        }
    
        public long getSalary() {
            return this.salary;
        }
    
        public void setSalary(long salary) {
            this.salary = salary;
        }
    }
    View Code

    第三步:定义一个Aggregator(MyAverage),实现对员工薪资求平均值的功能;

    import org.apache.spark.sql.Encoder;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.expressions.Aggregator;
    
    /**
     * Typed aggregator that computes the average salary of a Dataset of {@link Employee}.
     * The intermediate state is an {@link Average} (running sum and count).
     * For empty input the result is {@code null}, mirroring SQL {@code AVG}, rather than NaN.
     */
    public class MyAverage extends Aggregator<Employee, Average, Double> {
        private static final long serialVersionUID = 1L;
    
        /** A zero value for this aggregation: any b merged with zero must equal b. */
        @Override
        public Average zero() {
            return new Average(0L, 0L);
        }
    
        /**
         * Adds one employee's salary to the buffer. For performance the buffer is
         * mutated and returned instead of allocating a new object.
         */
        @Override
        public Average reduce(Average buffer, Employee employee) {
            buffer.setSum(buffer.getSum() + employee.getSalary());
            buffer.setCount(buffer.getCount() + 1);
            return buffer;
        }
    
        /** Merges two partial results into {@code b1} and returns it. */
        @Override
        public Average merge(Average b1, Average b2) {
            b1.setSum(b1.getSum() + b2.getSum());
            b1.setCount(b1.getCount() + b2.getCount());
            return b1;
        }
    
        /**
         * Turns the final buffer into the average.
         *
         * @return sum / count, or {@code null} when no rows were aggregated
         *         (avoids returning NaN from 0.0 / 0)
         */
        @Override
        public Double finish(Average reduction) {
            if (reduction.getCount() == 0L) {
                return null;
            }
            return ((double) reduction.getSum()) / reduction.getCount();
        }
    
        /** Encoder for the intermediate value type. */
        @Override
        public Encoder<Average> bufferEncoder() {
            return Encoders.bean(Average.class);
        }
    
        /** Encoder for the final output value type. */
        @Override
        public Encoder<Double> outputEncoder() {
            return Encoders.DOUBLE();
        }
    }

    第四步:在Spark中调用该Aggregator进行验证。

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.MapFunction;
    import org.apache.spark.sql.*;;
    
    import java.util.ArrayList;
    import java.util.List;
    
    public class SparkClient {
        /**
         * Demo entry point: builds a typed Dataset of employees and computes their
         * average salary with the {@code MyAverage} typed aggregator.
         *
         * @param args unused
         */
        public static void main(String[] args) {
            final SparkSession spark = SparkSession.builder().master("local[*]").appName("test_agg").getOrCreate();
            try {
                List<Employee> employeeList = new ArrayList<Employee>();
                employeeList.add(new Employee("Michael", 3000L));
                employeeList.add(new Employee("Andy", 4500L));
                employeeList.add(new Employee("Justin", 3500L));
                employeeList.add(new Employee("Berta", 4000L));
    
                // Build the typed Dataset directly with a bean encoder. This avoids the
                // DataFrame -> Row -> Employee round-trip and its fragile positional column access.
                Dataset<Employee> ds = spark.createDataset(employeeList, Encoders.bean(Employee.class));
    
                ds.show();
                // +-------+------+
                // |   name|salary|
                // +-------+------+
                // |Michael|  3000|
                // |   Andy|  4500|
                // | Justin|  3500|
                // |  Berta|  4000|
                // +-------+------+
    
                MyAverage myAverage = new MyAverage();
                // Convert the function to a `TypedColumn` and give it a name
                TypedColumn<Employee, Double> averageSalary = myAverage.toColumn().name("average_salary");
                Dataset<Double> result = ds.select(averageSalary);
                result.show();
                // +--------------+
                // |average_salary|
                // +--------------+
                // |        3750.0|
                // +--------------+
            } finally {
                // Release the local Spark context even if a step fails.
                spark.stop();
            }
        }
    }

    输出:

    +-------+------+
    |   name|salary|
    +-------+------+
    |Michael|  3000|
    |   Andy|  4500|
    | Justin|  3500|
    |  Berta|  4000|
    +-------+------+
    
    +--------------+
    |average_salary|
    +--------------+
    |        3750.0|
    +--------------+

    参考:

    https://www.cnblogs.com/LHWorldBlog/p/8432210.html

    https://blog.csdn.net/kwu_ganymede/article/details/50462020

    https://my.oschina.net/cloudcoder/blog/640009

    https://blog.csdn.net/xgjianstart/article/details/54956413

  • 相关阅读:
    php 单双引号的区别
    SpringBoot动态代理使用Cglib还是jdk的问题
    SpringBoot MyBatis 一级缓存和二级的配置理解
    SpringBoot+MyBatis,显示SQL方式
    java lambda分页
    关于Spring的@Value注解使用Integer方式
    mysql死锁,等待资源,事务锁,Lock wait timeout exceeded; try restarting transaction解决
    关于Integer包装类对象之间值的比较
    你未必了解的DNS
    SpringCloudConfig报错Cannot clone or checkout repository:https://gitee.com/yanfa401/config-repo
  • 原文地址:https://www.cnblogs.com/yy3b2007com/p/9294345.html
Copyright © 2011-2022 走看看