  • Spark custom functions (UDF, UDAF)

    Spark version: 2.3

    Test data used in this article (JSON):

    {"name":"lillcol", "age":24,"ip":"192.168.0.8"}
    {"name":"adson", "age":100,"ip":"192.168.255.1"}
    {"name":"wuli", "age":39,"ip":"192.143.255.1"}
    {"name":"gu", "age":20,"ip":"192.168.255.1"}
    {"name":"ason", "age":15,"ip":"243.168.255.9"}
    {"name":"tianba", "age":1,"ip":"108.168.255.1"}
    {"name":"clearlove", "age":25,"ip":"222.168.255.110"}
    {"name":"clearlove", "age":30,"ip":"222.168.255.110"}
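
    If you don't want to depend on a local file path, the same test data can also be built in memory. A minimal sketch (assuming a SparkSession named spark, as in the examples below; jsonLines is my own name):

    import spark.implicits._ // String encoder for createDataset

    // Build the sample rows as a Dataset[String] and let Spark infer the JSON schema
    val jsonLines = Seq(
      """{"name":"lillcol", "age":24,"ip":"192.168.0.8"}""",
      """{"name":"adson", "age":100,"ip":"192.168.255.1"}"""
      // ... remaining rows from the sample above
    )
    val ds = spark.read.json(spark.createDataset(jsonLines))
    ds.printSchema() // age: long, ip: string, name: string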
    

    User-defined UDF

    There are two ways to define a custom UDF:

    1. SQLContext.udf.register()
    2. Create a UserDefinedFunction

    The two approaches differ in scope: a function registered by name can only be referenced inside SQL text, while a UserDefinedFunction can be applied directly to columns in the DataFrame DSL, as the short sketch below and the full example after it both show.
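
    A minimal sketch of the scope difference (the function names are my own; it assumes a SparkSession spark, a DataFrame ds registered as temp view table1, and spark.implicits._ in scope, as in the full example below):

    import org.apache.spark.sql.functions.udf

    // 1. Register by name: the function can be referenced inside SQL text
    spark.udf.register("addPrefix", (s: String) => "name:" + s)
    spark.sql("select addPrefix(name) from table1").show()

    // 2. Create a UserDefinedFunction: the function is applied to columns in the DSL
    val addPrefix = udf((s: String) => "name:" + s)
    ds.select(addPrefix($"name")).show()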

    package com.test.spark
    
    import org.apache.spark.sql.expressions.UserDefinedFunction
    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    
    /**
      * @author Administrator
      *         2019/7/22-14:04
      *
      */
    object TestUdf {
    
      val spark = SparkSession
        .builder()
        .appName("TestCreateDataset")
        .config("spark.some.config.option", "some-value")
        .master("local")
        .enableHiveSupport()
        .getOrCreate()
      val sQLContext = spark.sqlContext
    
      import spark.implicits._
    
    
      def main(args: Array[String]): Unit = {
        testudf
      }
    
      def testudf() = {
        val iptoLong: UserDefinedFunction = getIpToLong()
        val ds: Dataset[Row] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson")
        ds.createOrReplaceTempView("table1")
        sQLContext.udf.register("addName", sqlUdf(_: String)) // addName can only be used in SQL, not in the DSL
        //1.SQL
        sQLContext.sql("select *,addName(name) as nameAddName  from table1")
          .show()
        //2.DSL
        val addName: UserDefinedFunction = udf((str: String) => ("ip: " + str))
        ds.select($"*", addName($"ip").as("ipAddName"))
          .show()
    
        // If the UDF is relatively complex, it can be factored out into its own definition, e.g. iptoLong
        ds.select($"*", iptoLong($"ip").as("iptoLong"))
          .show()
      }
    
      def sqlUdf(name: String): String = {
        "name:" + name
      }
    
      /**
        * User-defined UDF
        *
        * @return
        */
      def getIpToLong(): UserDefinedFunction = {
        val ipToLong: UserDefinedFunction = udf((ip: String) => {
          val arr: Array[String] = ip.replace(" ", "").replace("\"", "").split("\\.")
          var result: Long = 0
          var ipl: Long = 0
          if (arr.length == 4) {
            for (i <- 0 to 3) {
              ipl = arr(i).toLong
              result |= ipl << ((3 - i) << 3)
            }
          } else {
            result = -1
          }
          result
        })
        ipToLong
      }
    
    
    }
    
    Output:
    +---+---------------+---------+--------------+
    |age|             ip|     name|   nameAddName|
    +---+---------------+---------+--------------+
    | 24|    192.168.0.8|  lillcol|  name:lillcol|
    |100|  192.168.255.1|    adson|    name:adson|
    | 39|  192.143.255.1|     wuli|     name:wuli|
    | 20|  192.168.255.1|       gu|       name:gu|
    | 15|  243.168.255.9|     ason|     name:ason|
    |  1|  108.168.255.1|   tianba|   name:tianba|
    | 25|222.168.255.110|clearlove|name:clearlove|
    | 30|222.168.255.110|clearlove|name:clearlove|
    +---+---------------+---------+--------------+
    
    +---+---------------+---------+-------------------+
    |age|             ip|     name|          ipAddName|
    +---+---------------+---------+-------------------+
    | 24|    192.168.0.8|  lillcol|    ip: 192.168.0.8|
    |100|  192.168.255.1|    adson|  ip: 192.168.255.1|
    | 39|  192.143.255.1|     wuli|  ip: 192.143.255.1|
    | 20|  192.168.255.1|       gu|  ip: 192.168.255.1|
    | 15|  243.168.255.9|     ason|  ip: 243.168.255.9|
    |  1|  108.168.255.1|   tianba|  ip: 108.168.255.1|
    | 25|222.168.255.110|clearlove|ip: 222.168.255.110|
    | 30|222.168.255.110|clearlove|ip: 222.168.255.110|
    +---+---------------+---------+-------------------+
    
    +---+---------------+---------+----------+
    |age|             ip|     name|  iptoLong|
    +---+---------------+---------+----------+
    | 24|    192.168.0.8|  lillcol|3232235528|
    |100|  192.168.255.1|    adson|3232300801|
    | 39|  192.143.255.1|     wuli|3230662401|
    | 20|  192.168.255.1|       gu|3232300801|
    | 15|  243.168.255.9|     ason|4087938825|
    |  1|  108.168.255.1|   tianba|1823014657|
    | 25|222.168.255.110|clearlove|3735617390|
    | 30|222.168.255.110|clearlove|3735617390|
    +---+---------------+---------+----------+
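
    The bit-shifting in getIpToLong packs each octet of the IPv4 address into one 32-bit value (the first octet goes into the highest byte). A standalone sketch of the same logic, independent of Spark, that can be checked against the first row of the table above:

    def ipToLong(ip: String): Long = {
      val arr = ip.trim.split("\\.")
      if (arr.length != 4) -1L
      // shift the running value left one byte, then OR in the next octet
      else arr.foldLeft(0L)((acc, octet) => (acc << 8) | octet.toLong)
    }

    println(ipToLong("192.168.0.8")) // 3232235528, matching the iptoLong column for lillcol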
    

    User-defined UDAF (i.e., aggregate functions)

    Weakly typed user-defined aggregate function

    Implemented by extending UserDefinedAggregateFunction.

    package com.test.spark
    
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    
    /**
      * @author lillcol
      *         2019/7/22-15:09
      *         Weakly typed user-defined aggregate function
      */
    object TestUDAF extends UserDefinedAggregateFunction {
      // Data type of the aggregate function's input arguments
      // :: prepends an element to the head of a list, producing a new list; Nil is the empty List, defined as List[Nothing]
      override def inputSchema: StructType = StructType(StructField("age", IntegerType) :: Nil)

      // Equivalent to:
      //  override def inputSchema: StructType = new StructType().add("age", IntegerType)
    
      // Data types of the values in the aggregation buffer
      override def bufferSchema: StructType = {
        StructType(StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
      }
    
      // Data type of the value returned by this UserDefinedAggregateFunction.
      override def dataType: DataType = DoubleType
    
      // Whether this function is deterministic: given the same input, it always returns the same output.
      override def deterministic: Boolean = true
    
      //  Initializes the given aggregation buffer, i.e., the zero value of the buffer.
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        // sum: total of all ages
        buffer(0) = 0
        // count: number of people
        buffer(1) = 0
      }
    
      //  Updates the given aggregation buffer with new input data.
      // Called once per input row (within the same partition).
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getInt(0) + input.getInt(0) // accumulate age
        buffer(1) = buffer.getInt(1) + 1 // accumulate count
      }
    
      //  Merges two aggregation buffers and stores the updated values back into buffer1.
      // Called when two partially aggregated results are combined (across partitions).
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0) // accumulate age
        buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1) // accumulate count
      }
    
      override def evaluate(buffer: Row): Any = {
        buffer.getInt(0).toDouble / buffer.getInt(1)
      }
    
      val spark = SparkSession
        .builder()
        .appName("Spark SQL basic example")
        // .config("spark.some.config.option", "some-value")
        .master("local[*]") // local testing
        .getOrCreate()
    
      import spark.implicits._
    
      def main(args: Array[String]): Unit = {
        spark.udf.register("myAvg", TestUDAF)
        val ds: Dataset[Row] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson")
        ds.createOrReplaceTempView("table1")
        //SQL
        spark.sql("select myAvg(age) as avgAge from table1")
          .show()
    
        //DSL
        val myavg = TestUDAF
        ds.select(myavg($"age").as("avgAge"))
          .show()
      }
    }
    
    Output:
    +------+
    |avgAge|
    +------+
    | 31.75|
    +------+
    
    +------+
    |avgAge|
    +------+
    | 31.75|
    +------+
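
    As a quick sanity check (not part of the original post), the same 31.75 should come back from Spark's built-in avg, both in SQL and in the DSL, assuming the ds and table1 from the example above are still in scope:

    import org.apache.spark.sql.functions.avg

    // built-in average via SQL
    spark.sql("select avg(age) as avgAge from table1").show()
    // built-in average via the DSL
    ds.select(avg($"age").as("avgAge")).show()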
    

    Strongly typed user-defined aggregate function

    Implemented by extending Aggregator (the one under org.apache.spark.sql.expressions; be careful not to import the wrong package).

    package com.test.spark
    
    import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
    import org.apache.spark.sql.expressions._
    
    /**
      * @author Administrator
      *         2019/7/22-16:07
      *
      */
    // Since this is strongly typed, it is natural to use case classes
    case class Person(name: String, age: Double, ip: String)
    
    case class Average(var sum: Double, var count: Double)
    
    object MyAverage extends Aggregator[Person, Average, Double] {
      //  The zero value of this aggregation; it should satisfy b + zero = b for any b.
      //  Defines a data structure that holds the running sum of ages and the count, both initialized to 0.
      override def zero: Average = {
        Average(0, 0)
      }
    
      //  Combines two values to produce a new value. For performance, the function may modify b and return it instead of constructing a new object.
      //  Merges data within the same executor (same partition).
      override def reduce(b: Average, a: Person): Average = {
        b.sum += a.age
        b.count += 1
        b
      }
    
      // Merges two intermediate values.
      // Combines the results of different executors (different partitions).
      override def merge(b1: Average, b2: Average): Average = {
        b1.sum += b2.sum
        b1.count += b2.count
        b1
      }
    
      // Computes the final result.
      override def finish(reduction: Average): Double = {
        reduction.sum / reduction.count
      }
    
      //  Specifies the Encoder for the intermediate value type.
      override def bufferEncoder: Encoder[Average] = Encoders.product
    
      //  Specifies the Encoder for the final output value type.
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    
      val spark = SparkSession
        .builder()
        .appName("Spark SQL basic example")
        // .config("spark.some.config.option", "some-value")
        .master("local[*]") // local testing
        .getOrCreate()
    
      import spark.implicits._
    
      def main(args: Array[String]): Unit = {
        val ds: Dataset[Person] = spark.read.json("D:\\DATA-LG\\PUBLIC\\TYGQ\\INF\\testJson").as[Person]
        ds.show()
    
        val avgAge = MyAverage.toColumn.name("avgAge") // set the column alias to avgAge here
        ds.select(avgAge) // calling avgAge.as("columnName") here throws org.apache.spark.sql.AnalysisException; the alias can only be set above with .name (at least in my tests)
          .show()
      }
    }
    
    Output:
    +---+---------------+---------+
    |age|             ip|     name|
    +---+---------------+---------+
    | 24|    192.168.0.8|  lillcol|
    |100|  192.168.255.1|    adson|
    | 39|  192.143.255.1|     wuli|
    | 20|  192.168.255.1|       gu|
    | 15|  243.168.255.9|     ason|
    |  1|  108.168.255.1|   tianba|
    | 25|222.168.255.110|clearlove|
    | 30|222.168.255.110|clearlove|
    +---+---------------+---------+
    
    +------+
    |avgAge|
    +------+
    | 31.75|
    +------+
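
    The Aggregator is not limited to a whole-Dataset aggregate; with groupByKey it can also be applied per group. A small sketch (the alias avgAge is my own choice; ds is the Dataset[Person] from the example above):

    // average age per name, reusing the same typed Aggregator
    ds.groupByKey(_.name)
      .agg(MyAverage.toColumn.name("avgAge"))
      .show()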
    

    This is an original article; please credit the source when reposting!
