zoukankan      html  css  js  c++  java
  • sparksql系列(六) SparkSql中UDF、UDAF、UDTF

    RDD没有可以这种可以注册的方法。

    在使用sparksql过程中发现UDF还是有点用的所以,还是单独写一篇博客记录一下。

    UDF=》一个输入一个输出。相当于map

    UDAF=》多个输入一个输出。相当于reduce

    UDTF=》一个输入多个输出。相当于flatMap。(需要hive环境,暂时未测试)

    UDF

            其实就是在sql语句中注册函数,不要想得太难了。给大家写一个case when的语句

            import java.util.Arrays

            import org.apache.spark.SparkConf
            import org.apache.spark.api.java.JavaSparkContext
            import org.apache.spark.sql.{ DataFrame, Row, SparkSession, functions }
            import org.apache.spark.sql.functions.{ col, desc, length, row_number, trim, when }
            import org.apache.spark.sql.functions.{ countDistinct, sum, count, avg }
            import org.apache.spark.sql.functions.concat
            import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
            import org.apache.spark.sql.expressions.Window
            import org.apache.spark.storage.StorageLevel
            import org.apache.spark.sql.SaveMode
            import java.util.ArrayList

            object WordCount {

                    def main(args: Array[String]): Unit = {
                            val sparkSession = SparkSession.builder().master("local").getOrCreate()
                            val javasc = new JavaSparkContext(sparkSession.sparkContext)

                            val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}", "{'id':'8'}",
                                    "{'id':'9'}","{'id':'10'}"));
                            val nameRDD1df = sparkSession.read.json(nameRDD1)

                            nameRDD1df.createTempView("idList")
            
                            sparkSession.udf.register("idParse",(str:String)=>{//注册一个函数,实现case when的函数
                                    str match{
                                            case "7" => "id7"
                                            case "8" => "id8"
                                            case "9" => "id9"
                                            case _=>"others"
                                    }
                            })
                            val data = sparkSession.sql("select idParse(id) from idList").show(100)
                    }
            }

            以上是UDF的sql用法,下面介绍data frame用法        

            import java.util.Arrays
            import org.apache.spark.SparkConf
            import org.apache.spark.api.java.JavaSparkContext
            import org.apache.spark.sql.{ DataFrame, Row, SparkSession, functions }
            import org.apache.spark.sql.functions.{ col,udf, desc, length, row_number, trim, when }
            import org.apache.spark.sql.functions.{ countDistinct, sum, count, avg }
            import org.apache.spark.sql.functions.concat
            import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
            import org.apache.spark.sql.expressions.Window
            import org.apache.spark.storage.StorageLevel
            import org.apache.spark.sql.SaveMode
            import java.util.ArrayList

            object WordCount {

                    def myUdf(value:String): String ={
                            println(value)
                            value+"|"
                    }
                    def main(args: Array[String]): Unit = {
                            val sparkSession = SparkSession.builder().master("local").getOrCreate()
                            val javasc = new JavaSparkContext(sparkSession.sparkContext)

                            val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}", "{'id':'8'}","{'id':'9'}","{'id':'10'}"));
                            val fun = udf(myUdf _ )
                            val nameRDD1df = sparkSession.read.json(nameRDD1)
                                    .select(fun(col("id")) as "id").show(100)

                    }
            }

    UDAF

            import java.util.Arrays

            import org.apache.spark.SparkConf
            import org.apache.spark.api.java.JavaSparkContext
            import org.apache.spark.sql.{ DataFrame, Row, SparkSession, functions }
            import org.apache.spark.sql.functions.{ col, desc, length, row_number, trim, when }
            import org.apache.spark.sql.functions.{ countDistinct, sum, count, avg }
            import org.apache.spark.sql.functions.concat
            import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
            import org.apache.spark.sql.expressions.Window
            import org.apache.spark.storage.StorageLevel
            import org.apache.spark.sql.SaveMode
            import java.util.ArrayList
            import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
            import org.apache.spark.sql.expressions.MutableAggregationBuffer
            import org.apache.spark.sql.types.IntegerType
            import org.apache.spark.sql.types.DataType

            class MyMax extends UserDefinedAggregateFunction{
                    //定义输入数据的类型,两种写法都可以
                    //override def inputSchema: StructType = StructType(Array(StructField("input", IntegerType, true)))
                    override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)
                    //定义聚合过程中所处理的数据类型
                    // override def bufferSchema: StructType = StructType(Array(StructField("cache", IntegerType, true)))
                    override def bufferSchema: StructType = StructType(StructField("max", IntegerType) :: Nil)
                    //定义输入数据的类型
                    override def dataType: DataType = IntegerType
                    //规定一致性
                    override def deterministic: Boolean = true
                    //在聚合之前,每组数据的初始化操作
                    override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0}
                    //每组数据中,当新的值进来的时候,如何进行聚合值的计算
                    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
                            if(input.getInt(0)> buffer.getInt(0))
                                    buffer(0)=input.getInt(0)
                    }
                    //合并各个分组的结果
                    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
                            if(buffer2.getInt(0)> buffer1.getInt(0)){
                                    buffer1(0)=buffer2.getInt(0)
                            }
                    }
                    //返回最终结果
                    override def evaluate(buffer: Row): Any = {buffer.getInt(0)}
            }


            class MyAvg extends UserDefinedAggregateFunction{
                    //输入数据的类型
                    override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)
                    //中间结果数据的类型
                    override def bufferSchema: StructType = StructType(
                            StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
                    //定义输入数据的类型
                    override def dataType: DataType = IntegerType
                    //规定一致性
                    override def deterministic: Boolean = true
                    //初始化操作
                    override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0;buffer(1) =0;}

                    //map端reduce,所有数据必须过这一段代码
                    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
                            buffer.update(0, buffer.getInt(0)+input.getInt(0))
                            buffer.update(1, buffer.getInt(1)+1)
                    }
                    //reduce数据,update里面Row,没有第二个字段,这时候就有了第二个字段
                    override def merge(buffer: MutableAggregationBuffer, input: Row): Unit = {
                            buffer.update(0, buffer.getInt(0)+input.getInt(0))
                            buffer.update(1, buffer.getInt(1)+input.getInt(1))
                    }
                    //返回最终结果
                    override def evaluate(finalVaue: Row): Int = {finalVaue.getInt(0)/finalVaue.getInt(1)}
                    }

                    object WordCount {

                            def main(args: Array[String]): Unit = {
                                    val sparkSession = SparkSession.builder().master("local").getOrCreate()
                                    val javasc = new JavaSparkContext(sparkSession.sparkContext)

                                    val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}"));
                                    val nameRDD1df = sparkSession.read.json(nameRDD1)
                                    val nameRDD2 = javasc.parallelize(Arrays.asList( "{'id':'8'}"));
                                    val nameRDD2df = sparkSession.read.json(nameRDD2)
                                    val nameRDD3 = javasc.parallelize(Arrays.asList("{'id':'9'}"));
                                    val nameRDD3df = sparkSession.read.json(nameRDD3)
                                    val nameRDD4 = javasc.parallelize(Arrays.asList("{'id':'10'}"));
                                    val nameRDD4df = sparkSession.read.json(nameRDD4)

                                    nameRDD1df.union(nameRDD2df).union(nameRDD3df).union(nameRDD4df).registerTempTable("idList")

                                    // sparkSession.udf.register("myMax",new MyMax)
                                    sparkSession.udf.register("myAvg",new MyAvg)

                                    val data = sparkSession.sql("select myAvg(id) from idList").show(100)


                    }
            }

    UDTF 暂时没测试,家里没有hive环境

           import java.util.Arrays

           import org.apache.spark.SparkConf
           import org.apache.spark.api.java.JavaSparkContext
           import org.apache.spark.sql.{ DataFrame, Row, SparkSession, functions }
           import org.apache.spark.sql.functions.{ col, desc, length, row_number, trim, when }
           import org.apache.spark.sql.functions.{ countDistinct, sum, count, avg }
           import org.apache.spark.sql.functions.concat
           import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
           import org.apache.spark.sql.expressions.Window
           import org.apache.spark.storage.StorageLevel
           import org.apache.spark.sql.SaveMode
           import java.util.ArrayList
           import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
           import org.apache.spark.sql.expressions.MutableAggregationBuffer
           import org.apache.spark.sql.types.IntegerType
           import org.apache.spark.sql.types.DataType
           import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
           import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
           import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
           import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
           import org.apache.hadoop.hive.ql.exec.UDFArgumentException
           import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException
           import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

           class MyFloatMap extends GenericUDTF{
                  override def close(): Unit = {}
                  //这个方法的作用:1.输入参数校验 2. 输出列定义,可以多于1列,相当于可以生成多行多列数据
                  override def initialize(args:Array[ObjectInspector]): StructObjectInspector = {
                         if (args.length != 1) {
                                throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
                         }
                         if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
                                throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
                         }

                         val fieldNames = new java.util.ArrayList[String]
                         val fieldOIs = new java.util.ArrayList[ObjectInspector]

                         //这里定义的是输出列默认字段名称
                         fieldNames.add("col1")
                         //这里定义的是输出列字段类型
                         fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

                         ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
                  }

                  //这是处理数据的方法,入参数组里只有1行数据,即每次调用process方法只处理一行数据
                  override def process(args: Array[AnyRef]): Unit = {
                         //将字符串切分成单个字符的数组
                         val strLst = args(0).toString.split("")
                         for(i <- strLst){
                                var tmp:Array[String] = new Array[String](1)
                                tmp(0) = i
                                //调用forward方法,必须传字符串数组,即使只有一个元素
                                forward(tmp)
                         }
                  }
           }

           object WordCount {

                  def main(args: Array[String]): Unit = {
                         val sparkSession = SparkSession.builder().master("local").getOrCreate()
                         val javasc = new JavaSparkContext(sparkSession.sparkContext)

                         val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}"));
                         val nameRDD1df = sparkSession.read.json(nameRDD1)

                         nameRDD1df.createOrReplaceTempView("idList")

                         sparkSession.sql("create temporary function myFloatMap as 'MyFloatMap'")

                         val data = sparkSession.sql("select myFloatMap(id) from idList").show(100)

                  }
           }

  • 相关阅读:
    国科大 高级人工智能 期末复习总结
    算法岗面试问题总结
    java如何判断溢出
    matrix67中适合程序员的例子
    java map
    tensorflow手写数字识别(有注释)
    epoch,iteration与batchsize的区别
    java中如何不自己写排序方法完成排序
    Kotlin实现《第一行代码》案例“酷欧天气”
    Kotlin入门第三课:数据类型
  • 原文地址:https://www.cnblogs.com/wuxiaolong4/p/11924172.html
Copyright © 2011-2022 走看看