zoukankan      html  css  js  c++  java
  • spark-sql自定义函数UDF和UDAF

    1 UDF对每个值进行处理;

    2 UDAF对分组后的每个值处理(必须分组)

        SparkConf sparkConf = new SparkConf()
                    .setMaster("local")
                    .setAppName("MySqlTest");
    
            JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
    
            SQLContext sqlContext = new SQLContext(javaSparkContext);
    
            List<String> list = new ArrayList<String>();
            list.add("2018-9-9,1,ab");
            list.add("2018-5-9,1124,abg");
            list.add("2018-9-9,1125,abc");
            list.add("2018-5-9,1126,abh");
            list.add("2016-10-9,1127,abc");
            list.add("2016-10-9,1127,abcd");
            list.add("2016-10-9,1127,abcder");
    
            JavaRDD<String> rdd_list = javaSparkContext.parallelize(list, 5);
    
            JavaRDD<Row> rdd_row_list = rdd_list.map(new Function<String, Row>() {
                @Override
                public Row call(String s) throws Exception {
                    return RowFactory.create(s.split(",")[0], Long.parseLong(s.split(",")[1]), s.split(",")[2]);//转换成一个row对象
                }
            });
    
            List<StructField> structFieldList = new ArrayList<StructField>();
            structFieldList.add(DataTypes.createStructField("date", DataTypes.StringType, true));
            structFieldList.add(DataTypes.createStructField("s", DataTypes.LongType, true));
            structFieldList.add(DataTypes.createStructField("str", DataTypes.StringType, true));
            StructType dyType = DataTypes.createStructType(structFieldList);
    
            DataFrame df_dyType = sqlContext.createDataFrame(rdd_row_list, dyType);
    
            df_dyType.registerTempTable("tmp_req");
            df_dyType.show();
    
            //1,注册一个简单用户自定义函数
            sqlContext.udf().register("zzq123", new UDF1<String, Integer>() {
                @Override
                public Integer call(String str) throws Exception {
                    return str.length();
                }
            }, DataTypes.IntegerType);
    
            DataFrame df_group = sqlContext.sql("select date,s,zzq123(date) as zzq123 from tmp_req ");//UDF如果没有指定名称,则随机名称
            df_group.show();
    
            //1,注册一个复杂的用户自定义聚合函数
            sqlContext.udf().register("zzq_agg", new StringLen());//zzq_agg函数计算出分组后本组所有字符串总长度
            DataFrame df_group_agg = sqlContext.sql("select date,zzq_agg(str) strSum  from tmp_req group by date ");//UDAF为聚合情况下使用
            df_group_agg.show();

    UDAF实体:

    public class StringLen extends UserDefinedAggregateFunction {
        @Override
        public StructType inputSchema() {//inputSchema指的是输入的数据类型
            List<StructField> fields = new ArrayList<StructField>();
            fields.add(DataTypes.createStructField("_string", DataTypes.StringType, true));
            return DataTypes.createStructType(fields);
        }
    
        @Override
        public StructType bufferSchema() {//bufferSchema指的是  中间进行聚合时  所处理的数据类型
            List<StructField> fields = new ArrayList<StructField>();
            fields.add(DataTypes.createStructField("_len", DataTypes.IntegerType, true));
            return DataTypes.createStructType(fields);
        }
    
        @Override
        public DataType dataType() {//dataType指的是函数返回值的类型
            return DataTypes.IntegerType;
        }
    
        @Override
        public boolean deterministic() {//一致性检验,如果为true,那么输入不变的情况下计算的结果也是不变的
            return true;
        }
    
        /**
         * 对于每个分组的数据进行最原始的初始化操作
         *
         * @param buffer
         */
        @Override
        public void initialize(MutableAggregationBuffer buffer) {
            buffer.update(0, 0);//初始化的时候初始最开始的字符串的长度
        }
    
        /**
         * 用输入数据input更新buffer值,类似于combineByKey
         *
         * @param buffer
         * @param input
         */
        @Override
        public void update(MutableAggregationBuffer buffer, Row input) {//分组后的每个值处理方法
            buffer.update(0, ((Integer) buffer.getAs(0)) + input.getAs(0).toString().length());//返回自己的长度
        }
    
        /**
         * 合并两个buffer,将buffer2合并到buffer1.在合并两个分区聚合结果的时候会被用到,类似于reduceByKey
         * 这里要注意该方法没有返回值,在实现的时候是把buffer2合并到buffer1中去,你需要实现这个合并细节
         *
         * @param buffer1
         * @param buffer2
         */
        @Override
        public void merge(MutableAggregationBuffer buffer1, Row buffer2) {//相当于shuffle环节,将每组在不同executor上的数据进行combiner
            buffer1.update(0, ((Integer) buffer1.getAs(0)) + ((Integer) buffer2.getAs(0)));//两次的字符串长度相加
        }
    
        /**
         * 计算并返回最终的聚合结果
         *
         * @param buffer
         * @return
         */
        @Override
        public Object evaluate(Row buffer) {
            return buffer.getInt(0);
        }
    }
  • 相关阅读:
    acm课程练习2--1005
    acm课程练习2--1003
    [ZJOI2010]网络扩容
    [ZJOI2009]狼和羊的故事
    [FJOI2007]轮状病毒
    [NOIP2016提高组]换教室
    [NOIP2016提高组]愤怒的小鸟
    [NOIP2009提高组]最优贸易
    [洛谷P2245]星际导航
    [NOIP2013提高组]货车运输
  • 原文地址:https://www.cnblogs.com/zzq-include/p/8758961.html
Copyright © 2011-2022 走看看