  【Spark-SQL Notes, Part 3】 UDF, UDAF, and Window Functions

    Environment
      VM: VMware 10
      Linux: CentOS-6.5-x86_64
      SSH client: Xshell 4
      FTP client: Xftp 4
      JDK 1.8
      Scala 2.10.4 (requires JDK 1.8)
      Spark 1.6


    1. UDF: user-defined function
    A UDF is implemented as a class (or anonymous class) that implements one of the UDF1, UDF2, ... interfaces, where the number matches the UDF's argument count.

    Example code:
    Java:

    package com.wjy.df;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SQLContext;
    import org.apache.spark.sql.api.java.UDF1;
    import org.apache.spark.sql.api.java.UDF2;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    public class UDF {
    
        public static void main(String[] args) {
            SparkConf conf = new SparkConf().setMaster("local").setAppName("UDF");
            JavaSparkContext sc = new JavaSparkContext(conf);
            SQLContext sqlContext = new SQLContext(sc);
            
            JavaRDD<String> rdd = sc.parallelize(Arrays.asList("xiaoming","xiaohong","xiaolei"));
            JavaRDD<Row> rdd2 = rdd.map(new Function<String, Row>() {
                private static final long serialVersionUID = 1L;
                @Override
                public Row call(String str) throws Exception {
                    return RowFactory.create(str);
                }
            });
            
            /**
             * Build the DataFrame by creating the schema dynamically
             */
            List<StructField> fields = new ArrayList<StructField>();
            fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
            StructType schema = DataTypes.createStructType(fields);
            DataFrame dataFrame = sqlContext.createDataFrame(rdd2, schema);
            dataFrame.registerTempTable("user");
            
            //Define a UDF that returns the length of a string
            /**
             * Implement UDF1, UDF2, ... according to the number of arguments the UDF takes
             */
            sqlContext.udf().register("StrLen", new UDF1<String, Integer>() {
                private static final long serialVersionUID = 1L;
    
                @Override
                public Integer call(String str) throws Exception {
                    return str.length();
                }
            },DataTypes.IntegerType);
            sqlContext.sql("select name ,StrLen(name) as length from user").show();
            /*
             * +--------+------+
             * |    name|length|
             * +--------+------+
             * |xiaoming|     8|
             * |xiaohong|     8|
             * | xiaolei|     7|
             * +--------+------+
             */
            
            sqlContext.udf().register("StrLen2", new UDF2<String, Integer, Integer>() {
                private static final long serialVersionUID = 1L;
    
                @Override
                public Integer call(String str, Integer num) throws Exception {
                    return str.length()+num;
                }
            }, DataTypes.IntegerType);
            sqlContext.sql("select name ,StrLen2(name,10) as length from user").show();
            /*
             * +--------+------+
             * |    name|length|
             * +--------+------+
             * |xiaoming|    18|
             * |xiaohong|    18|
             * | xiaolei|    17|
             * +--------+------+
             */
            
            sc.stop();
        }
    
    }

    Scala:

    package com.wjy.df
    
    import org.apache.spark.SparkConf
    import org.apache.spark.SparkContext
    import org.apache.spark.sql.RowFactory
    import org.apache.spark.sql.types.DataTypes
    import org.apache.spark.sql.types.StructField
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.sql.SQLContext
    
    object UDF {
      def main(args:Array[String]):Unit={
        val conf = new SparkConf().setMaster("local").setAppName("UDF");
        val sc = new SparkContext(conf);
        val sqlContext = new SQLContext(sc);
        val rdd = sc.makeRDD(Array("zhansan","lisi","wangwu"));
        val row = rdd.map(x=>{
          RowFactory.create(x);
        });
        val schema = DataTypes.createStructType(Array(StructField("name",StringType,true)));
        val df = sqlContext.createDataFrame(row, schema);
        df.show; //the show method can be called without parentheses
        df.registerTempTable("user");
        
        //StrLen
        sqlContext.udf.register("StrLen", (s:String)=>{s.length()});
        sqlContext.sql("select name ,StrLen(name) as length from user").show;
        
        //StrLen2
        sqlContext.udf.register("StrLen2", (s:String,i:Integer)=>{s.length()+i});
        sqlContext.sql("select name ,StrLen2(name,10) as length from user").show;
        
        sc.stop();
      }
    }
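
    Besides being called from SQL text, a registered UDF can also be invoked from the DataFrame API. A minimal sketch, reusing the sqlContext, df, and the StrLen registration from the Scala example above (callUDF, col, lit, and udf all come from org.apache.spark.sql.functions in Spark 1.6; strLenPlus is just an illustrative name):

    import org.apache.spark.sql.functions.{callUDF, col, lit, udf}

    // Invoke the SQL-registered UDF by name through the DataFrame API
    df.withColumn("length", callUDF("StrLen", col("name"))).show

    // Or build a typed UDF directly, without registering it for SQL at all
    val strLenPlus = udf((s: String, i: Int) => s.length + i)
    df.withColumn("length", strLenPlus(col("name"), lit(10))).show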

    2. UDAF: user-defined aggregate function
    To implement a UDAF, define a class that extends UserDefinedAggregateFunction.

    Example code:
    Java:

    package com.wjy.df;
    
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;
    
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.RowFactory;
    import org.apache.spark.sql.SQLContext;
    import org.apache.spark.sql.expressions.MutableAggregationBuffer;
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
    import org.apache.spark.sql.types.DataType;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructField;
    import org.apache.spark.sql.types.StructType;
    
    /**
     * UDAF: user-defined aggregate function
     * @author root
     *
     */
    public class UDAF {
    
        public static void main(String[] args) {
            SparkConf conf = new SparkConf().setMaster("local").setAppName("UDAF");
            JavaSparkContext sc = new JavaSparkContext(conf);
            SQLContext sqlContext = new SQLContext(sc);
            JavaRDD<String> parallelize = sc.parallelize(
                    Arrays.asList("zhangsan","lisi","wangwu","zhangsan","zhangsan","lisi"));
            JavaRDD<Row> rowRDD = parallelize.map(new Function<String, Row>() {
                private static final long serialVersionUID = 1L;
                @Override
                public Row call(String s) throws Exception {
                    return RowFactory.create(s);
                }
            });
            
            List<StructField> fields = new ArrayList<StructField>();
            fields.add(DataTypes.createStructField("name", DataTypes.StringType, true));
            StructType schema = DataTypes.createStructType(fields);
            DataFrame df = sqlContext.createDataFrame(rowRDD, schema);
            df.registerTempTable("user");
            
            /**
             * Register a UDAF that counts how many times each distinct value occurs.
             * Note: defining a named class that extends UserDefinedAggregateFunction works just as well
             * (the Scala version below does exactly that).
             */
            sqlContext.udf().register("StringCount",new UserDefinedAggregateFunction(){
                private static final long serialVersionUID = 1L;
    
                /**
                 * Initialize the aggregation buffer; this is each group's starting value before aggregation begins.
                 */
                @Override
                public void initialize(MutableAggregationBuffer buffer) {
                    buffer.update(0, 0);
                }
                
                /**
                 * Specify the input field(s) and their data types.
                 */
                @Override
                public StructType inputSchema() {
                    return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("name", DataTypes.StringType, true)));
                }
                
                /**
                 * update: the values in a group are passed in one row at a time; this defines the per-row aggregation logic.
                 * buffer.getInt(0) holds the result of the previous partial aggregation.
                 * This is analogous to a map-side combiner: each map task performs a small local aggregation,
                 * and the large aggregation happens on the reduce side.
                 * In other words: whenever a new value arrives for a group, update decides how that group's partial aggregate changes.
                 */
                @Override
                public void update(MutableAggregationBuffer buffer, Row arg1) {
                    buffer.update(0, buffer.getInt(0)+1);
                }
                
                /**
                 * The schema of the intermediate buffer used during aggregation.
                 */
                @Override
                public StructType bufferSchema() {
                    return DataTypes.createStructType(Arrays.asList(DataTypes.createStructField("buffer", DataTypes.IntegerType, true)));
                }
    
                /**
                 * merge: an update may cover only part of a group's data on one node, while the rest of the group
                 * is processed on other nodes; merge combines the partial results produced on each node.
                 * buffer1.getInt(0): the value accumulated so far by the global aggregation
                 * buffer2.getInt(0): the partial result produced by updates on another node
                 * In other words: once the distributed partial aggregations finish, a global-level merge is performed.
                 */
                @Override
                public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
                    buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
                }
                
                /**
                 * The data type of the UDAF's final result.
                 */
                @Override
                public DataType dataType() {
                    return DataTypes.IntegerType;
                }
    
                /**
                 * Return the UDAF's final result; its type must match dataType().
                 */
                @Override
                public Object evaluate(Row row) {
                    return row.getInt(0);
                }
                
                /**
                 * Whether the UDAF always produces the same result for a given set of inputs; usually true.
                 */
                @Override
                public boolean deterministic() {
                    return true;
                }
    
                });
            
            sqlContext.sql("select name ,StringCount(name) as strCount from user group by name").show();
            sc.stop();
        }
    
    }

    Scala:

    package com.wjy.df
    
    import org.apache.spark.SparkConf
    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.types.DataTypes
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.sql.RowFactory
    import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
    import org.apache.spark.sql.types.IntegerType
    import org.apache.spark.sql.expressions.MutableAggregationBuffer
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql.types.DataType
    
    class MyUDAF extends UserDefinedAggregateFunction{ 
      // Initialize the buffer value for each group
      def initialize(buffer: MutableAggregationBuffer): Unit = {
         buffer(0) = 0
      }
      
      // The type of the input data
      def inputSchema: StructType = {
        DataTypes.createStructType(Array(DataTypes.createStructField("input", StringType, true)))
      }
      
      // For each group, update the partial aggregate whenever a new value arrives
      def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getAs[Int](0)+1
      }
      
      // The schema of the intermediate buffer used during aggregation
      def bufferSchema: StructType = {
        DataTypes.createStructType(Array(DataTypes.createStructField("aaa", IntegerType, true)))
      }
      
      // Finally, merge the partial aggregates computed on the different nodes
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getAs[Int](0)+buffer2.getAs[Int](0) 
      }
      
      // The type of the function's final return value
      def dataType: DataType = {
        DataTypes.IntegerType
      }
      
      // Return the final aggregate value; it must match the type declared in dataType
      def evaluate(buffer: Row): Any = {
        buffer.getAs[Int](0)
      }
    
      // Deterministic: the same input always produces the same result
      def deterministic: Boolean = {
        true
      }
    }
    
    object UDAF {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        conf.setMaster("local").setAppName("udaf")
        val sc = new SparkContext(conf)
        val sqlContext = new SQLContext(sc)
        val rdd = sc.makeRDD(Array("zhangsan","lisi","wangwu","zhangsan","lisi"))
        val rowRDD = rdd.map { x => {RowFactory.create(x)} }
        
        val schema = DataTypes.createStructType(Array(DataTypes.createStructField("name", StringType, true)))
        val df = sqlContext.createDataFrame(rowRDD, schema)
        df.show()
        df.registerTempTable("user")
        /**
         * Register the UDAF
         */
        sqlContext.udf.register("StringCount", new MyUDAF())
        sqlContext.sql("select name ,StringCount(name) as count from user group by name").show()
        sc.stop()
      }
    }
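
    Because this UDAF simply counts rows per group, a built-in count(*) over the same table should return identical numbers, which makes for a quick sanity check. A minimal sketch, reusing the sqlContext, df, and the user table registered above:

    // The built-in aggregation should agree with StringCount
    sqlContext.sql("select name, count(*) as count from user group by name").show()

    // The same check through the DataFrame API
    df.groupBy("name").count().show()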

    3. Window functions
    Window function syntax:
    row_number() over (partition by XXX order by XXX)
    Notes:
    The row_number() window function partitions rows by one column and ranks them by another, which makes it easy to take the top N rows per group.
    In Spark 1.6, a SQL statement that uses a window function must be executed with a HiveContext, and a HiveContext cannot be created in local mode by default.

    Example code:
    Java:

    package com.wjy.df;
    
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.SaveMode;
    import org.apache.spark.sql.hive.HiveContext;
    
    public class RowNumberWindowFun {
    
        public static void main(String[] args) {
            SparkConf conf = new SparkConf();
            conf.setAppName("windowfun");
            conf.set("spark.sql.shuffle.partitions","1");
            JavaSparkContext sc = new JavaSparkContext(conf);
            HiveContext hiveContext = new HiveContext(sc);
            hiveContext.sql("use spark");
            hiveContext.sql("drop table if exists sales");
            hiveContext.sql("create table if not exists sales (riqi string,leibie string,jine Int) "
                    + "row format delimited fields terminated by '\t'");
            hiveContext.sql("load data local inpath '/root/test/sales' into table sales");
            
            /**
             * Window function syntax:
             * row_number() over (partition by XXX order by XXX DESC) as rank
             * Note: rank starts at 1
             */
            /**
             * Group by category (leibie), sort each category by amount (jine) in descending order,
             * and show [date (riqi), category, amount]. For example, given:
             *
             * 1 A 100
             * 2 B 200
             * 3 A 300
             * 4 B 400
             * 5 A 500
             * 6 B 600
             * after ranking:
             * 5 A 500  --rank 1
             * 3 A 300  --rank 2
             * 1 A 100  --rank 3
             * 6 B 600  --rank 1
             * 4 B 400  --rank 2
             * 2 B 200  --rank 3
             *
             */
            DataFrame result = hiveContext.sql("select riqi,leibie,jine "
                                + "from ("
                                    + "select riqi,leibie,jine,"
                                    + "row_number() over (partition by leibie order by jine desc) rank "
                                    + "from sales) t "
                            + "where t.rank<=3");
            result.show(100);
            /**
             * Save the result to the Hive table sales_result
             */
            result.write().mode(SaveMode.Overwrite).saveAsTable("sales_result");
            sc.stop();
        }
    }

    Scala:

    package com.wjy.df
    
    import org.apache.spark.SparkConf
    import org.apache.spark.SparkContext
    import org.apache.spark.sql.hive.HiveContext
    
    object RowNumberWindowFun {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        conf.setAppName("windowfun")
        val sc = new SparkContext(conf)
        val hiveContext = new HiveContext(sc)
        hiveContext.sql("use spark")
        hiveContext.sql("drop table if exists sales")
        hiveContext.sql("create table if not exists sales (riqi string,leibie string,jine Int) "
            + "row format delimited fields terminated by '\t'")
        hiveContext.sql("load data local inpath '/root/test/sales' into table sales")

        /**
         * Window function syntax:
         * row_number() over (partition by XXX order by XXX)
         */
        val result = hiveContext.sql("select riqi,leibie,jine "
            + "from ("
                + "select riqi,leibie,jine,"
                + "row_number() over (partition by leibie order by jine desc) rank "
                + "from sales) t "
            + "where t.rank<=3")
        result.show()
        sc.stop()
      }
    }
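
    The same per-category top-N ranking can also be expressed through the DataFrame window API instead of embedded SQL. A minimal sketch, assuming the same hiveContext and sales table as above (Window comes from org.apache.spark.sql.expressions and row_number from org.apache.spark.sql.functions, available under that name since Spark 1.6; in Spark 1.x window functions still require a HiveContext):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, row_number}

    // Rank rows within each category (leibie) by amount (jine), descending
    val w = Window.partitionBy("leibie").orderBy(col("jine").desc)
    val ranked = hiveContext.table("sales")
      .withColumn("rank", row_number().over(w))
      .filter(col("rank") <= 3)
    ranked.show()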

