zoukankan      html  css  js  c++  java
  • 转:Spark User Defined Aggregate Function (UDAF) using Java

    Sometimes the aggregate functions built into Spark are not enough, so Spark lets you plug in custom user-defined aggregate functions (UDAFs). Before diving into the code, let's first look at the methods of the class UserDefinedAggregateFunction that a UDAF must implement.

    1. inputSchema()

    In this method you need to define a StructType that represents the input arguments of this aggregate function.

    2. bufferSchema()

    In this method you need to define a StructType that represents the values in the aggregation buffer. This schema describes the intermediate values the aggregate function keeps while it is processing rows.

    3. dataType()

    Returns the DataType of the value returned by this aggregate function.

    4. initialize(MutableAggregationBuffer buffer)

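    This method is invoked to set up a fresh aggregation buffer whenever Spark starts aggregating a group (“key”), possibly once per partition. Use it to set the initial values of the buffer, e.g. to reset your counters to 0.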

    5. evaluate(Row buffer)

    This method calculates the final value of the aggregation from the aggregation buffer.

    6. update(MutableAggregationBuffer buffer, Row input)

    This method updates the aggregation buffer with a new input row; it is invoked every time a new input row arrives for the same key.

    7. merge(MutableAggregationBuffer buffer, Row input)

    This method merges two aggregation buffers (for example, partial results computed on different partitions) and stores the combined values in the first buffer.

    Below is a pictorial representation of how these methods work in Spark. The diagram assumes there are 2 aggregation buffers for your task.

    [diagram not reproduced]

    Let's see how we can write a UDAF that accepts multiple values as input and returns multiple values as output.

    My input file is a .txt file that contains 3 columns: city, female count and male count. We need to compute the total population of each city and its dominant gender (whichever of the female and male counts is larger).

    cities.txt

    Nashik 40 50
    Mumbai 50 60
    Pune 70 80
    Nashik 40 50
    Mumbai 50 60
    Pune 170 80

    The expected output is shown below:

    +--------+--------+--------+
    | city   |Dominant| Total  |
    +--------+--------+--------+
    | Mumbai | Male   | 220    |
    | Pune   | Female | 400    |
    | Nashik | Male   | 180    |
    +--------+--------+--------+

    Now let's write a UDAF class that extends the UserDefinedAggregateFunction class; the code below is commented to explain each method.

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.expressions.MutableAggregationBuffer;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    public class SparkUDAF extends UserDefinedAggregateFunction
    {
    private StructType inputSchema;
    private StructType bufferSchema;
    private DataType returnDataType =
    DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType);
    MutableAggregationBuffer mutableBuffer;
    
    public SparkUDAF()
    {
    //inputSchema : This UDAF can accept 2 inputs which are of type Integer
    List<StructField> inputFields = new ArrayList<StructField>();
    StructField inputStructField1 = DataTypes.createStructField(“femaleCount”,DataTypes.IntegerType, true);
    inputFields.add(inputStructField1);
    StructField inputStructField2 = DataTypes.createStructField(“maleCount”,DataTypes.IntegerType, true);
    inputFields.add(inputStructField2);
    inputSchema = DataTypes.createStructType(inputFields);
    
    //BufferSchema : This UDAF can hold calculated data in below mentioned buffers
    List<StructField> bufferFields = new ArrayList<StructField>();
    StructField bufferStructField1 = DataTypes.createStructField(“totalCount”,DataTypes.IntegerType, true);
    bufferFields.add(bufferStructField1);
    StructField bufferStructField2 = DataTypes.createStructField(“femaleCount”,DataTypes.IntegerType, true);
    bufferFields.add(bufferStructField2);
    StructField bufferStructField3 = DataTypes.createStructField(“maleCount”,DataTypes.IntegerType, true);
    bufferFields.add(bufferStructField3);
    StructField bufferStructField4 = DataTypes.createStructField(“outputMap”,DataTypes.createMapType(DataTypes.StringType, DataTypes.StringType), true);
    bufferFields.add(bufferStructField4);
    bufferSchema = DataTypes.createStructType(bufferFields);
    }
    
    /**
    * This method determines which bufferSchema will be used
    */
    @Override
    public StructType bufferSchema() {
    
    return bufferSchema;
    }
    
    /**
    * This method determines the return type of this UDAF
    */
    @Override
    public DataType dataType() {
    return returnDataType;
    }
    
    /**
    * Returns true iff this function is deterministic, i.e. given the same input, always return the same output.
    */
    @Override
    public boolean deterministic() {
    return true;
    }
    
    /**
    * This method will re-initialize the variables to 0 on change of city name
    */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
    buffer.update(0, 0);
    buffer.update(1, 0);
    buffer.update(2, 0);
    mutableBuffer = buffer;
    }
    
    /**
    * This method is used to increment the count for each city
    */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
    buffer.update(0, buffer.getInt(0) + input.getInt(0) + input.getInt(1));
    buffer.update(1, input.getInt(0));
    buffer.update(2, input.getInt(1));
    }
    
    /**
    * This method will be used to merge data of two buffers
    */
    @Override
    public void merge(MutableAggregationBuffer buffer, Row input) {
    
    buffer.update(0, buffer.getInt(0) + input.getInt(0));
    buffer.update(1, buffer.getInt(1) + input.getInt(1));
    buffer.update(2, buffer.getInt(2) + input.getInt(2));
    
    }
    
    /**
    * This method calculates the final value by referring the aggregation buffer
    */
    @Override
    public Object evaluate(Row buffer) {
    //In this method we are preparing a final map that will be returned as output
    Map<String,String> op = new HashMap<String,String>();
    op.put(“Total”, “” + mutableBuffer.getInt(0));
    op.put(“dominant”, “Male”);
    if(buffer.getInt(1) > mutableBuffer.getInt(2))
    {
    op.put(“dominant”, “Female”);
    }
    mutableBuffer.update(3,op);
    
    return buffer.getMap(3);
    }
    /**
    * This method will determine the input schema of this UDAF
    */
    @Override
    public StructType inputSchema() {
    
    return inputSchema;
    }
    
    }
    
    Now let's see how we can use this UDAF from our Spark code.
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.StringTokenizer;
    
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SQLContext;
    import org.apache.spark.sql.hive.HiveContext;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    public class TestDemo {
    public static void main (String args[])
    {
    //Set up sparkContext and SQLContext
    SparkConf conf = new SparkConf().setAppName(“udaf”).setMaster(“local”);
    JavaSparkContext sc = new JavaSparkContext(conf);
    SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
    
    //create Row RDD
    JavaRDD<String> citiesRdd = sc.textFile(“cities.txt”);
    JavaRDD<Row> rowRdd = citiesRdd.map(new Function<String, Row>() {
    public Row call(String line) throws Exception {
    StringTokenizer st = new StringTokenizer(line,” “);
    return RowFactory.create(st.nextToken().trim(),Integer.parseInt(st.nextToken().trim()),Integer.parseInt(st.nextToken().trim()));
    }
    });
    
    //Create Struct Type
    List<StructField> inputFields = new ArrayList<StructField>();
    StructField inputStructField = DataTypes.createStructField(“city”,DataTypes.StringType, true);
    inputFields.add(inputStructField);
    StructField inputStructField2 = DataTypes.createStructField(“Female”,DataTypes.IntegerType, true);
    inputFields.add(inputStructField2);
    StructField inputStructField3 = DataTypes.createStructField(“Male”,DataTypes.IntegerType, true);
    inputFields.add(inputStructField3);
    StructType inputSchema = DataTypes.createStructType(inputFields);
    
    //Create Data Frame
    DataFrame df = sqlContext.createDataFrame(rowRdd, inputSchema);
    
    //Register our Spark UDAF
    SparkUDAF sparkUDAF = new SparkUDAF();
    sqlContext.udf().register(“uf”,sparkUDAF);
    
    //Register dataframe as table
    df.registerTempTable(“cities”);
    
    //Run query
    sqlContext.sql(“SELECT city , count[‘dominant’] as Dominant, count[‘Total’] as Total from(select city, uf(Female,Male) as count from cities group by (city)) temp”).show(false);
    
    }
    }

    文章来自:https://blog.augmentiq.in/2016/08/05/spark-multiple-inputoutput-user-defined-aggregate-function-udaf-using-java/

  • 相关阅读:
    iqueryable lambda表达式
    win10安装后耳机有声音而外放无声音
    Coursera机器学习week11 笔记
    Coursera机器学习week10 单元测试
    Coursera机器学习week10 笔记
    Coursera机器学习week9 编程作业
    Coursera机器学习week9 单元测试
    Coursera机器学习week9 笔记
    Coursera机器学习week8 编程作业
    Coursera机器学习week8 单元测试
  • 原文地址:https://www.cnblogs.com/leixingzhi7/p/6213714.html
Copyright © 2011-2022 走看看