Spark Learning Notes 2

    SparkSQL

    SparkSQL Overview

    SparkSQL Core Programming

    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_Basic {
    
      def main(args: Array[String]): Unit = {
        // TODO: Create the SparkSQL runtime environment
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
        val spark = SparkSession.builder().config(sparkConf).getOrCreate()
        import spark.implicits._
    
        // TODO: Read a JSON file into a DataFrame
        val df = spark.read.json("input/user.json")
    //    df.show()
    
        df.createOrReplaceTempView("user")
    //    spark.sql("select * from user").show
    
        // When working with a DataFrame, column-expression operations require the implicit conversion rules (import spark.implicits._)
    //    df.select("age", "username").show
    
    //    df.select($"age" + 1).show
    
        // DataSet
        val seq = Seq(1, 2, 3, 4)
        val ds = seq.toDS()
    //    ds.show
    
        val rdd = spark.sparkContext.makeRDD(List((1, "zhangsan", 30), (2, "lisi", 40)))
        val df2 = rdd.toDF("id", "name", "age")
        val rowRDD = df2.rdd
    
        val ds2 = df2.as[User]
        val df3 = ds2.toDF()
    
        val ds3 = rdd.map {
          case (id, name, age) => {
            User(id, name, age)
          }
        }.toDS()
    
        val userRDD = ds3.rdd
    
    
        // Close the environment
        spark.close()
    
      }
    
      case class User(id: Int, name: String, age: Int)
    }
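
    The snippets in this section read input/user.json. The notes never show that file, but since spark.read.json expects one JSON object per line, and the queries reference the fields username and age, a minimal input consistent with the code (sample values assumed) could be:

    {"username": "zhangsan", "age": 30}
    {"username": "lisi", "age": 40}
    {"username": "wangwu", "age": 50}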
    

    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_UDF {
    
      def main(args: Array[String]): Unit = {
        // TODO: Create the SparkSQL runtime environment
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
        val spark = SparkSession.builder().config(sparkConf).getOrCreate()
        import spark.implicits._
    
        val df = spark.read.json("input/user.json")
        df.createOrReplaceTempView("user")
    
        spark.udf.register("prefixName", (name:String) => {
          "Name" + name
        })
    
        spark.sql("select age, prefixName(username) from user").show()
    
    
        spark.close()
      }
    }
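
    The registered UDF above is invoked through SQL; the same function can also be applied through the DataFrame API. A minimal sketch (same SparkSession and user.json assumed):

    import org.apache.spark.sql.functions.udf

    val prefixName = udf((name: String) => "Name" + name)
    df.select($"age", prefixName($"username")).show()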
    

    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.{Row, SparkSession}
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_UDAF {
    
      def main(args: Array[String]): Unit = {
        // TODO: Create the SparkSQL runtime environment
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
        val spark = SparkSession.builder().config(sparkConf).getOrCreate()
        import spark.implicits._
    
        val df = spark.read.json("input/user.json")
    
        df.createOrReplaceTempView("user")
    
        spark.udf.register("avgAge", new MyAvgUDAF())
    
        spark.sql("select avgAge(age) from user").show
    
        spark.close()
      }
    
      /*
         Custom aggregate function class: computes the average age
         1. Extend UserDefinedAggregateFunction
         2. Override its 8 methods
         */
      class MyAvgUDAF extends UserDefinedAggregateFunction{
    
        // Schema of the input data
        override def inputSchema: StructType = {
          StructType(
            Array(
              StructField("age", LongType)
            )
          )
        }
    
        // Schema of the aggregation buffer
        override def bufferSchema: StructType = {
          StructType(
            Array(
              StructField("total", LongType),
              StructField("count", LongType)
            )
          )
        }
    
        // Data type of the aggregation result
        override def dataType: DataType = LongType
    
        // Whether the function is deterministic (same input always yields the same output)
        override def deterministic: Boolean = true
    
        // Initialize the buffer
        override def initialize(buffer: MutableAggregationBuffer): Unit = {
          buffer.update(0, 0L)
          buffer.update(1, 0L)
        }
    
        // Update the buffer with an input row
        override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
          buffer.update(0, buffer.getLong(0) + input.getLong(0))
          buffer.update(1, buffer.getLong(1) + 1)
        }
    
        // Merge two buffers
        override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
          buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
          buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
        }
    
        // Compute the final result from the buffer
        override def evaluate(buffer: Row): Any = {
          buffer.getLong(0) / buffer.getLong(1)
        }
      }
    }
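
    Note that evaluate above divides two Longs, so the average is truncated (ages 30 and 40 yield 35, but 30 and 45 would yield 37, not 37.5). A sketch of the two changes needed to return a fractional result instead (DoubleType comes from org.apache.spark.sql.types):

    // Report the result as a Double rather than a truncated Long
    override def dataType: DataType = DoubleType

    override def evaluate(buffer: Row): Any = {
      buffer.getLong(0).toDouble / buffer.getLong(1)
    }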
    
    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_UDF1 {
    
      def main(args: Array[String]): Unit = {
        // TODO: Create the SparkSQL runtime environment
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
        val spark = SparkSession.builder().config(sparkConf).getOrCreate()
        import spark.implicits._
    
        val df = spark.read.json("input/user.json")
        df.createOrReplaceTempView("user")
    
        spark.udf.register("ageAvg", functions.udaf(new MyAvgUDAF()))
    
        spark.sql("select ageAvg(age) from user").show
    
        spark.close()
      }
    
      /*
         Custom aggregate function class: computes the average age
         1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters:
             IN : the input type, Long
             BUF : the buffer type, Buff
             OUT : the output type, Long
         2. Override its 6 methods
         */
      case class Buff(var total: Long, var count: Long)
      class MyAvgUDAF extends Aggregator[Long, Buff, Long] {
    
        // Initialize the buffer
        override def zero: Buff = {
          Buff(0L, 0L)
        }
    
        // Update the buffer with an input value
        override def reduce(b: Buff, a: Long): Buff = {
          b.total = b.total + a
          b.count = b.count + 1
          b
        }
    
        // Merge two buffers
        override def merge(b1: Buff, b2: Buff): Buff = {
          b1.total = b1.total + b2.total
          b1.count = b1.count + b2.count
          b1
        }
    
        // Compute the final result
        override def finish(reduction: Buff): Long = {
          reduction.total / reduction.count
        }
    
        // Encoders for the buffer and the output
        override def bufferEncoder: Encoder[Buff] = Encoders.product
    
        override def outputEncoder: Encoder[Long] = Encoders.scalaLong
      }
    }
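
    functions.udaf, which adapts a strongly typed Aggregator for SQL registration, is available from Spark 3.0 onward. Once registered, the function can also be called from the DataFrame API through expr; a small sketch (df as read above):

    import org.apache.spark.sql.functions.expr

    df.select(expr("ageAvg(age)")).show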
    
    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_UDF2 {
    
      def main(args: Array[String]): Unit = {
        // TODO: Create the SparkSQL runtime environment
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
        val spark = SparkSession.builder().config(sparkConf).getOrCreate()
        import spark.implicits._
    
        val df = spark.read.json("input/user.json")
        // In earlier Spark versions, strongly typed UDAF operations could not be used in SQL
        // SQL & DSL
        // Early strongly typed UDAFs were applied through the DSL syntax instead
        val ds = df.as[User]
    
        // Convert the UDAF into a Column object usable in a query
        val udafCol = new MyAvgUDAF().toColumn
    
        ds.select(udafCol).show
    
        spark.close()
      }
    
      /*
         Custom aggregate function class: computes the average age
         1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters:
             IN : the input type, User
             BUF : the buffer type, Buff
             OUT : the output type, Long
         2. Override its 6 methods
         */
      case class User(username: String, age:Long)
      case class Buff(var total: Long, var count: Long)
      class MyAvgUDAF extends Aggregator[User, Buff, Long]{
    
        // Initialize the buffer
        override def zero: Buff = {
          Buff(0L, 0L)
        }
    
        // Update the buffer with an input value
        override def reduce(b: Buff, a: User): Buff = {
          b.total = b.total + a.age
          b.count = b.count + 1
          b
        }
    
        // Merge two buffers
        override def merge(b1: Buff, b2: Buff): Buff = {
          b1.total = b1.total + b2.total
          b1.count = b1.count + b2.count
          b1
        }
    
        // Compute the final result
        override def finish(reduction: Buff): Long = {
          reduction.total / reduction.count
        }
    
        // Encoders for the buffer and the output
        override def bufferEncoder: Encoder[Buff] = Encoders.product
    
        override def outputEncoder: Encoder[Long] = Encoders.scalaLong
      }
    }
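
    The column produced by toColumn gets an auto-generated name; TypedColumn.name can alias it, e.g.:

    ds.select(udafCol.name("avgAge")).show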
    

    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.{SaveMode, SparkSession}
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_JDBC {
    
      def main(args: Array[String]): Unit = {
        // TODO: Create the SparkSQL runtime environment
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
        val spark = SparkSession.builder().config(sparkConf).getOrCreate()
        import spark.implicits._
    
        // Read data from MySQL
        val df = spark.read
          .format("jdbc")
          .option("url", "jdbc:mysql://hadoop102:3306/test")
          .option("driver", "com.mysql.jdbc.Driver")
          .option("user", "root")
          .option("password", "********")
          .option("dbtable", "users")
          .load()
    
    //    df.show
    
        // Write data to MySQL
        val df2 = spark.sparkContext.makeRDD(List((3, "wangwu", 30))).toDF("id", "name", "age")
        df2.write
          .format("jdbc")
          .option("url", "jdbc:mysql://hadoop102:3306/test")
          .option("driver", "com.mysql.jdbc.Driver")
          .option("user", "root")
          .option("password", "******")
          .option("dbtable", "users")
          .mode(SaveMode.Append)
          .save()
    
        spark.close()
      }
    }
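
    Spark also offers a jdbc shorthand that takes the connection settings as a java.util.Properties object; an equivalent read with the same (assumed) connection parameters:

    import java.util.Properties

    val props = new Properties()
    props.setProperty("user", "root")
    props.setProperty("password", "********")
    props.setProperty("driver", "com.mysql.jdbc.Driver")
    val usersDF = spark.read.jdbc("jdbc:mysql://hadoop102:3306/test", "users", props)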
    

    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_Hive {
    
      def main(args: Array[String]): Unit = {
    
        System.setProperty("HADOOP_USER_NAME", "root")
    
        // TODO: Create the SparkSQL runtime environment
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
        val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
        import spark.implicits._
    
        // Connecting SparkSQL to an external Hive:
        // 1. Copy hive-site.xml into the classpath
        // 2. Enable Hive support (enableHiveSupport)
        // 3. Add the required dependencies (including the MySQL driver for the metastore)
        spark.sql("show databases").show
    
        spark.close()
      }
    }
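
    With Hive support enabled, a DataFrame can also be persisted as a Hive table. A minimal sketch (hypothetical table name; uses the session and implicits from above):

    import org.apache.spark.sql.SaveMode

    val userDF = Seq((1, "zhangsan", 30)).toDF("id", "name", "age")
    userDF.write.mode(SaveMode.Overwrite).saveAsTable("db_hive.user_backup")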
    

    SparkSQL Hands-On Project

    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_Test {
    
      def main(args: Array[String]): Unit = {
    
        System.setProperty("HADOOP_USER_NAME", "lotuslaw")
        // TODO: Create the SparkSQL runtime environment
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
        val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
        import spark.implicits._
    
        spark.sql("use db_hive")
    
        // Prepare the data
        spark.sql(
          """
            |CREATE TABLE IF NOT EXISTS `user_visit_action`(
            |  `date` string,
            |  `user_id` bigint,
            |  `session_id` string,
            |  `page_id` bigint,
            |  `action_time` string,
            |  `search_keyword` string,
            |  `click_category_id` bigint,
            |  `click_product_id` bigint,
            |  `order_category_ids` string,
            |  `order_product_ids` string,
            |  `pay_category_ids` string,
            |  `pay_product_ids` string,
            |  `city_id` bigint)
            |row format delimited fields terminated by '\t'
            |""".stripMargin
        )
    
        spark.sql(
          """
            |load data local inpath 'input/user_visit_action.txt' into table db_hive.user_visit_action
            |""".stripMargin
        )
    
        spark.sql(
          """
            |CREATE TABLE IF NOT EXISTS `product_info`(
            |  `product_id` bigint,
            |  `product_name` string,
            |  `extend_info` string)
            |row format delimited fields terminated by '\t'
            |""".stripMargin
        )
    
        spark.sql(
          """
            |load data local inpath 'input/product_info.txt' into table db_hive.product_info
            |""".stripMargin
        )
    
        spark.sql(
          """
            |CREATE TABLE IF NOT EXISTS `city_info`(
            |  `city_id` bigint,
            |  `city_name` string,
            |  `area` string)
            |row format delimited fields terminated by '\t'
            |""".stripMargin
        )
    
        spark.sql(
          """
            |load data local inpath 'input/city_info.txt' into table db_hive.city_info
            |""".stripMargin
        )
    
        spark.sql(
          """
            |select * from city_info
            |""".stripMargin).show
    
    
        spark.close()
      }
    }
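
    Before closing the session, a quick sanity check (illustrative) confirms the three tables were loaded:

    spark.sql("select count(*) from user_visit_action").show
    spark.sql("select count(*) from product_info").show
    spark.sql("select count(*) from city_info").show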
    
    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.SparkSession
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_Test1 {
    
      def main(args: Array[String]): Unit = {
    
        System.setProperty("HADOOP_USER_NAME", "lotuslaw")
        // TODO: Create the SparkSQL runtime environment
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
        val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
    
        spark.sql("use db_hive")
    
        spark.sql(
          """
            |select
            |    *
            |from (
            |    select
            |        *,
            |        rank() over( partition by area order by clickCnt desc ) as rank
            |    from (
            |        select
            |           area,
            |           product_name,
            |           count(*) as clickCnt
            |        from (
            |            select
            |               a.*,
            |               p.product_name,
            |               c.area,
            |               c.city_name
            |            from user_visit_action a
            |            join product_info p on a.click_product_id = p.product_id
            |            join city_info c on a.city_id = c.city_id
            |            where a.click_product_id > -1
            |        ) t1 group by area, product_name
            |    ) t2
            |) t3 where rank <= 3
                """.stripMargin).show
    
    
        spark.close()
      }
    }
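
    The same top-3 ranking can be written with the DataFrame API using a window specification. A sketch, assuming the per-area click counts have already been aggregated into a hypothetical DataFrame clickDF with columns area, product_name and clickCnt (spark.implicits._ in scope for the $ syntax):

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{desc, rank}

    val w = Window.partitionBy("area").orderBy(desc("clickCnt"))
    clickDF.withColumn("rank", rank().over(w))
      .where($"rank" <= 3)
      .show()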
    
    package com.lotuslaw.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
    import org.apache.spark.sql.expressions.Aggregator
    
    import scala.collection.mutable
    import scala.collection.mutable.ListBuffer
    
    
    /**
     * @author: lotuslaw
     * @version: V1.0
     * @package: com.lotuslaw.spark.sql
     * @create: 2021-12-02 20:05
     * @description:
     */
    object Spark_SparkSQL_Test2 {
    
      def main(args: Array[String]): Unit = {
        System.setProperty("HADOOP_USER_NAME", "lotuslaw")
    
        val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
        val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
    
        spark.sql("use db_hive")
    
        // Query the base data
        spark.sql(
          """
            |  select
            |     a.*,
            |     p.product_name,
            |     c.area,
            |     c.city_name
            |  from user_visit_action a
            |  join product_info p on a.click_product_id = p.product_id
            |  join city_info c on a.city_id = c.city_id
            |  where a.click_product_id > -1
                """.stripMargin).createOrReplaceTempView("t1")
    
        // Aggregate the data by area and product
        // The UDAF operates on the rows within each (area, product_name) group
        spark.udf.register("cityRemark", functions.udaf(new CityRemarkUDAF()))
        spark.sql(
          """
            |  select
            |     area,
            |     product_name,
            |     count(*) as clickCnt,
            |     cityRemark(city_name) as city_remark
            |  from t1 group by area, product_name
                """.stripMargin).createOrReplaceTempView("t2")
    
        // Rank by click count within each area
        spark.sql(
          """
            |  select
            |      *,
            |      rank() over( partition by area order by clickCnt desc ) as rank
            |  from t2
                """.stripMargin).createOrReplaceTempView("t3")
    
        // Take the top 3
        spark.sql(
          """
            | select
            |     *
            | from t3 where rank <= 3
                """.stripMargin).show(false)
    
        spark.close()
      }
    
      case class Buffer(var total: Long, var cityMap: mutable.Map[String, Long])
    
      // Custom aggregate function: builds the city remark string
      // 1. Extend Aggregator and define the type parameters:
      //    IN : the city name
      //    BUF : Buffer => (total click count, Map[(city, cnt), (city, cnt)])
      //    OUT : the remark string
      // 2. Override its 6 methods
      class CityRemarkUDAF extends Aggregator[String, Buffer, String] {
        // Initialize the buffer
        override def zero: Buffer = {
          Buffer(0, mutable.Map[String, Long]())
        }
    
        // Update the buffer with one city click
        override def reduce(buff: Buffer, city: String): Buffer = {
          buff.total += 1
          val newCount = buff.cityMap.getOrElse(city, 0L) + 1
          buff.cityMap.update(city, newCount)
          buff
        }
    
        // Merge two buffers
        override def merge(buff1: Buffer, buff2: Buffer): Buffer = {
          buff1.total += buff2.total
    
          val map1 = buff1.cityMap
          val map2 = buff2.cityMap
    
          map2.foreach {
            case (city, cnt) => {
              val newCount = map1.getOrElse(city, 0L) + cnt
              map1.update(city, newCount)
            }
          }
          buff1.cityMap = map1
          buff1
        }
    
        // Build the remark string from the aggregated statistics
        override def finish(buff: Buffer): String = {
          val remarkList = ListBuffer[String]()
    
          val totalcnt = buff.total
          val cityMap = buff.cityMap
    
          // Sort cities by click count in descending order and keep the top two
          val cityCntList = cityMap.toList.sortWith(
            (left, right) => {
              left._2 > right._2
            }
          ).take(2)
    
          val hasMore = cityMap.size > 2
          var rsum = 0L
          cityCntList.foreach {
            case (city, cnt) => {
              val r = cnt * 100 / totalcnt
              remarkList.append(s"${city} ${r}%")
              rsum += r
            }
          }
          if (hasMore) {
            remarkList.append(s"其他 ${100 - rsum}%")
          }
    
          remarkList.mkString(", ")
        }
    
        override def bufferEncoder: Encoder[Buffer] = Encoders.product
    
        override def outputEncoder: Encoder[String] = Encoders.STRING
      }
    
    }
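
    As a quick illustration of the remark format, CityRemarkUDAF can be smoke-tested on an in-memory DataFrame (hypothetical data; needs import spark.implicits._):

    val clicks = Seq("北京", "北京", "上海", "深圳").toDF("city_name")
    clicks.createOrReplaceTempView("clicks")
    spark.udf.register("cityRemark", functions.udaf(new CityRemarkUDAF()))
    spark.sql("select cityRemark(city_name) as remark from clicks").show(false)
    // Expected remark along the lines of: 北京 50%, 上海 25%, 其他 25% (ties may order differently)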
    