  • Using Spark to Operate on MySQL and Hive, and Write Results Back to MySQL

    A recent project required statistical analysis over nearly 7 billion rows. Storing that much data in MySQL makes it very hard to read back; even with a search engine on top, queries are extremely slow. After some research we decided to build on our company's big-data platform and use Spark to handle statistics at this scale.

    To make later development easier, I wrote a few utility classes that hide the MySQL and Hive plumbing, so developers only need to care about the business code.

    The utility classes are as follows:

    I. Operating MySQL from Spark

    1. Get a Spark DataFrame from a SQL statement:

      /**
       * Get a DataFrame from a MySQL database
       *
       * @param spark SparkSession
       * @param sql   query SQL
       * @return DataFrame
       */
      def getDFFromMysql(spark: SparkSession, sql: String): DataFrame = {
        println(s"url:${mySqlConfig.url} user:${mySqlConfig.user} sql: ${sql}")
        spark.read.format("jdbc")
          .option("url", mySqlConfig.url)
          .option("user", mySqlConfig.user)
          .option("password", mySqlConfig.password)
          .option("driver", "com.mysql.jdbc.Driver")
          .option("query", sql)
          .load()
      }
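
    For example, a minimal usage sketch (the SQL below is a placeholder; `MySQLUtils` is the object these helpers live in, as the later sections show, and `SparkUtils.getProdSparkSession` is defined in section II):

      val spark = SparkUtils.getProdSparkSession
      val supplierDF = MySQLUtils.getDFFromMysql(
        spark,
        "SELECT id, registered_money FROM scfc_supplier_info WHERE yn = 1")
      supplierDF.show(10)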

    2. Write a Spark DataFrame into a MySQL table

      /**
       * Write the result into MySQL
       * @param df DataFrame
       * @param mode SaveMode
       * @param tableName table name
       */
      def writeIntoMySql(df: DataFrame, mode: SaveMode, tableName: String): Unit ={
        mode match {
          case SaveMode.Append => appendDataIntoMysql(df, tableName);
          case SaveMode.Overwrite => overwriteMysqlData(df, tableName);
          case _ => throw new Exception("Only Append and Overwrite are supported for now!")
        }
      }
      /**
       * Append the dataset to a MySQL table
       * @param df DataFrame
       * @param mysqlTableName table name: database_name.table_name
       * @return
       */
      def appendDataIntoMysql(df: DataFrame, mysqlTableName: String) = {
        df.write.mode(SaveMode.Append).jdbc(mySqlConfig.url, mysqlTableName, getMysqlProp)
      }
      /**
       * Overwrite a MySQL table with the dataset (truncate first, then append)
       * @param df DataFrame
       * @param mysqlTableName table name: database_name.table_name
       * @return
       */
      def overwriteMysqlData(df: DataFrame, mysqlTableName: String) = {
        //first clear the existing rows in the MySQL table
        truncateMysqlTable(mysqlTableName)
        //then append the new data
        df.write.mode(SaveMode.Append).jdbc(mySqlConfig.url, mysqlTableName, getMysqlProp)
      }
      /**
       * Truncate a table (remove all of its rows)
       * @param mysqlTableName
       * @return true if the truncate succeeded
       */
      def truncateMysqlTable(mysqlTableName: String): Boolean = {
        val conn = MySQLPoolManager.getMysqlManager.getConnection //take a connection from the pool
        val preparedStatement = conn.createStatement()
        try {
          preparedStatement.execute(s"truncate table $mysqlTableName")
          true
        } catch {
          case e: Exception =>
            println(s"mysql truncateMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
            false
        } finally {
          preparedStatement.close()
          conn.close()
        }
      }

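    A short usage sketch for the write helpers (hypothetical call; `df` is any DataFrame whose schema matches the target table, and the table name is a placeholder):

      //Append, or truncate-and-reload, depending on the SaveMode passed in.
      MySQLUtils.writeIntoMySql(df, SaveMode.Append, "db_name.scfc_kanban_field_value")
      MySQLUtils.writeIntoMySql(df, SaveMode.Overwrite, "db_name.scfc_kanban_field_value")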
    3. Delete rows from a MySQL table by condition

      /**
        * Delete the rows of a table that match a condition
        * @param mysqlTableName
        * @param condition
        * @return true if the delete succeeded
        */
      def deleteMysqlTableData(mysqlTableName: String, condition: String): Boolean = {
        val conn = MySQLPoolManager.getMysqlManager.getConnection //take a connection from the pool
        val preparedStatement = conn.createStatement()
        try {
          preparedStatement.execute(s"delete from $mysqlTableName where $condition")
          true
        } catch {
          case e: Exception =>
            println(s"mysql deleteMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
            false
        } finally {
          preparedStatement.close()
          conn.close()
        }
      }
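
    For example (a hypothetical call; the table, condition, and date are placeholders):

      //Remove a previous day's rows before re-running a statistics job.
      val deleted = MySQLUtils.deleteMysqlTableData(
        "db_name.scfc_kanban_field_value",
        "statistics_date = '2020-01-09'")
      println(s"delete succeeded: $deleted")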

    4. Save a DataFrame to MySQL, creating the table automatically if it does not exist

      /**
        * Save a DataFrame to MySQL, creating the table automatically if it does not exist
        * @param tableName
        * @param resultDateFrame
        */
      def saveDFtoDBCreateTableIfNotExist(tableName: String, resultDateFrame: DataFrame) {
        //create the table from the DataFrame schema if it does not exist yet
        createTableIfNotExist(tableName, resultDateFrame)
        //verify that the table and the DataFrame have the same number of fields, with the same names and order
        verifyFieldConsistency(tableName, resultDateFrame)
        //save the DataFrame
        saveDFtoDBUsePool(tableName, resultDateFrame)
      }
      /**
        * If the table does not exist, create it from the DataFrame's fields, keeping the DataFrame's column order.
        * If the DataFrame contains a field named id, it becomes the primary key (int, auto-increment);
        * the other fields are mapped from their Spark DataType to a MySQL column type.
        *
        * @param tableName table name
        * @param df        DataFrame
        * @return
        */
      def createTableIfNotExist(tableName: String, df: DataFrame): Unit = {
        val con = MySQLPoolManager.getMysqlManager.getConnection
        try {
          val metaData = con.getMetaData
          val colResultSet = metaData.getColumns(null, "%", tableName, "%")
          //create the table only if it does not exist yet
          if (!colResultSet.next()) {
            //build the CREATE TABLE statement
            val sb = new StringBuilder(s"CREATE TABLE `$tableName` (")
            df.schema.fields.foreach(x =>
              if (x.name.equalsIgnoreCase("id")) {
                sb.append(s"`${x.name}` int(255) NOT NULL AUTO_INCREMENT PRIMARY KEY,") //a field named id becomes an auto-increment integer primary key
              } else {
                x.dataType match {
                  case _: ByteType => sb.append(s"`${x.name}` int(100) DEFAULT NULL,")
                  case _: ShortType => sb.append(s"`${x.name}` int(100) DEFAULT NULL,")
                  case _: IntegerType => sb.append(s"`${x.name}` int(100) DEFAULT NULL,")
                  case _: LongType => sb.append(s"`${x.name}` bigint(100) DEFAULT NULL,")
                  case _: BooleanType => sb.append(s"`${x.name}` tinyint DEFAULT NULL,")
                  case _: FloatType => sb.append(s"`${x.name}` float DEFAULT NULL,")
                  case _: DoubleType => sb.append(s"`${x.name}` double DEFAULT NULL,")
                  case _: StringType => sb.append(s"`${x.name}` varchar(50) DEFAULT NULL,")
                  case _: TimestampType => sb.append(s"`${x.name}` timestamp DEFAULT current_timestamp,")
                  case _: DateType => sb.append(s"`${x.name}` date DEFAULT NULL,")
                  case _ => throw new RuntimeException(s"nonsupport ${x.dataType} !!!")
                }
              }
            )
            sb.append(") ENGINE=InnoDB DEFAULT CHARSET=utf8")
            val sql_createTable = sb.deleteCharAt(sb.lastIndexOf(',')).toString()
            println(sql_createTable)
            val statement = con.createStatement()
            statement.execute(sql_createTable)
            statement.close()
          }
        } finally {
          con.close()
        }
      }
      /**
        * Verify that the table and the DataFrame have the same field count, names, and order
        *
        * @param tableName table name
        * @param df        DataFrame
        */
      def verifyFieldConsistency(tableName: String, df: DataFrame): Unit = {
        val con = MySQLPoolManager.getMysqlManager.getConnection
        val metaData = con.getMetaData
        val colResultSet = metaData.getColumns(null, "%", tableName, "%")
        colResultSet.last()
        val tableFiledNum = colResultSet.getRow
        val dfFiledNum = df.columns.length
        if (tableFiledNum != dfFiledNum) {
          throw new Exception(s"The table and the DataFrame have different field counts!! table--$tableFiledNum but dataFrame--$dfFiledNum")
        }
        for (i <- 1 to tableFiledNum) {
          colResultSet.absolute(i)
          val tableFileName = colResultSet.getString("COLUMN_NAME")
          val dfFiledName = df.columns.apply(i - 1)
          if (!tableFileName.equals(dfFiledName)) {
            throw new Exception(s"The table and the DataFrame have different field names!! table--'$tableFileName' but dataFrame--'$dfFiledName'")
          }
        }
        colResultSet.beforeFirst()
        con.close()
      }
      /**
        * Write the DataFrame into MySQL through the c3p0 connection pool,
        * binding each column with the JDBC setter that matches its Spark DataType
        *
        * @param tableName       table name
        * @param resultDateFrame DataFrame
        */
      def saveDFtoDBUsePool(tableName: String, resultDateFrame: DataFrame) {
        val colNumbers = resultDateFrame.columns.length
        val sql = getInsertSql(tableName, colNumbers)
        val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
        resultDateFrame.foreachPartition(partitionRecords => {
          val conn = MySQLPoolManager.getMysqlManager.getConnection //take a connection from the pool
          val preparedStatement = conn.prepareStatement(sql)
          val metaData = conn.getMetaData.getColumns(null, "%", tableName, "%") //column metadata of the target table
          try {
            conn.setAutoCommit(false)
            partitionRecords.foreach(record => {
              //note: PreparedStatement setters are 1-based, Row getters are 0-based
              for (i <- 1 to colNumbers) {
                val value = record.get(i - 1)
                val dateType = columnDataTypes(i - 1)
                if (value != null) { //non-null value: bind it with the setter matching its type
                  dateType match {
                    case _: ByteType => preparedStatement.setInt(i, record.getByte(i - 1).toInt)
                    case _: ShortType => preparedStatement.setInt(i, record.getShort(i - 1).toInt)
                    case _: IntegerType => preparedStatement.setInt(i, record.getAs[Int](i - 1))
                    case _: LongType => preparedStatement.setLong(i, record.getAs[Long](i - 1))
                    case _: BooleanType => preparedStatement.setBoolean(i, record.getAs[Boolean](i - 1))
                    case _: FloatType => preparedStatement.setFloat(i, record.getAs[Float](i - 1))
                    case _: DoubleType => preparedStatement.setDouble(i, record.getAs[Double](i - 1))
                    case _: StringType => preparedStatement.setString(i, record.getAs[String](i - 1))
                    case _: TimestampType => preparedStatement.setTimestamp(i, record.getAs[Timestamp](i - 1))
                    case _: DateType => preparedStatement.setDate(i, record.getAs[Date](i - 1))
                    case _ => throw new RuntimeException(s"nonsupport ${dateType} !!!")
                  }
                } else { //null value: set SQL NULL using the column's JDBC type
                  metaData.absolute(i)
                  preparedStatement.setNull(i, metaData.getInt("DATA_TYPE"))
                }
              }
              preparedStatement.addBatch()
            })
            preparedStatement.executeBatch()
            conn.commit()
          } catch {
            case e: Exception => println(s"@@ saveDFtoDBUsePool error: ${ExceptionUtil.getExceptionStack(e)}")
            // do some log
          } finally {
            preparedStatement.close()
            conn.close()
          }
        })
      }
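
    `MySQLPoolManager`, `getMysqlProp`, `getInsertSql`, and `ExceptionUtil` are referenced above but not included in the post. The sketches below show one way they could look, inferred from how they are called (c3p0 is used because the post mentions it); they are assumptions, not the original implementations:

      import java.io.{PrintWriter, StringWriter}
      import com.mchange.v2.c3p0.ComboPooledDataSource

      //Assumed c3p0-backed pool (placeholder URL/user/password, as in SparkUtils below).
      object MySQLPoolManager {
        lazy val getMysqlManager: MysqlManager = new MysqlManager

        class MysqlManager {
          private val pool = new ComboPooledDataSource()
          pool.setDriverClass("com.mysql.jdbc.Driver")
          pool.setJdbcUrl("jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&useSSL=false")
          pool.setUser("user")
          pool.setPassword("password")
          pool.setMinPoolSize(5)
          pool.setMaxPoolSize(20)

          def getConnection: java.sql.Connection = pool.getConnection
        }
      }

      //Assumed JDBC properties used by the DataFrame writer (appendDataIntoMysql / overwriteMysqlData).
      def getMysqlProp: java.util.Properties = {
        val prop = new java.util.Properties()
        prop.setProperty("user", "user")
        prop.setProperty("password", "password")
        prop.setProperty("driver", "com.mysql.jdbc.Driver")
        prop
      }

      //Assumed insert-statement builder: one '?' placeholder per column, bound in table column order.
      def getInsertSql(tableName: String, colNumbers: Int): String = {
        val placeholders = (1 to colNumbers).map(_ => "?").mkString(",")
        s"insert into $tableName values ($placeholders)"
      }

      //Assumed helper that turns an exception's stack trace into a String for logging.
      object ExceptionUtil {
        def getExceptionStack(e: Throwable): String = {
          val sw = new StringWriter()
          e.printStackTrace(new PrintWriter(sw))
          sw.toString
        }
      }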

    II. Operating Spark

    1. Switching the Spark environment

    Define the environment in Profile.scala

    /**
     * @description
     * scf
     * @author wangxuexing
     * @date 2019/12/23
     */
    object Profile extends Enumeration{
      type Profile = Value
      /**
       * production environment
       */
      val PROD = Value("prod")
      /**
       * production test environment
       */
      val PROD_TEST = Value("prod_test")
      /**
       * development environment
       */
      val DEV = Value("dev")

      /**
       * the currently active environment
       */
      val currentEvn = PROD
    }

    Define SparkUtil.scala

    import com.dmall.scf.Profile
    import com.dmall.scf.dto.{Env, MySqlConfig}
    import org.apache.spark.sql.{DataFrame, Encoder, SparkSession}
    
    import scala.collection.JavaConversions._
    
    /**
     * @description Spark utility class
     * scf
     * @author wangxuexing
     * @date 2019/12/23
     */
    object SparkUtils {
      //development environment
      val DEV_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"
      val DEV_USER = "user"
      val DEV_PASSWORD = "password"

      //production test environment
      val PROD_TEST_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false"
      val PROD_TEST_USER = "user"
      val PROD_TEST_PASSWORD = "password"

      //production environment
      val PROD_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"
      val PROD_USER = "user"
      val PROD_PASSWORD = "password"

      def env = Profile.currentEvn
    
      /**
       * Get the environment settings (MySQL config + SparkSession) for the current profile
       * @return
       */
      def getEnv: Env ={
        env match {
          case Profile.DEV => Env(MySqlConfig(DEV_URL, DEV_USER, DEV_PASSWORD), SparkUtils.getDevSparkSession)
          case Profile.PROD =>
            Env(MySqlConfig(PROD_URL,PROD_USER,PROD_PASSWORD), SparkUtils.getProdSparkSession)
          case Profile.PROD_TEST =>
            Env(MySqlConfig(PROD_TEST_URL, PROD_TEST_USER, PROD_TEST_PASSWORD), SparkUtils.getProdSparkSession)
          case _ => throw new Exception("Unable to resolve the environment")
        }
      }
    
      /**
       * Get the production SparkSession
       * @return
       */
      def getProdSparkSession: SparkSession = {
        SparkSession
          .builder()
          .appName("scf")
          .enableHiveSupport()//enable Hive support
          .getOrCreate()
      }
    
      /**
       * Get the development (local) SparkSession
       * @return
       */
      def getDevSparkSession: SparkSession = {
        SparkSession
          .builder()
          .master("local[*]")
          .appName("local-1576939514234")
          .config("spark.sql.warehouse.dir", "C:\\data\\spark-ware")//if not set, defaults to a spark-warehouse directory under the project
          .enableHiveSupport()//enable Hive support
          .getOrCreate()
      }
    
      /**
       * Convert a DataFrame to a list of case class instances
       * @param df DataFrame
       * @tparam T case class
       * @return
       */
      def dataFrame2Bean[T: Encoder](df: DataFrame, clazz: Class[T]): List[T] = {
        val fieldNames = clazz.getDeclaredFields.map(f => f.getName).toList
        df.toDF(fieldNames: _*).as[T].collectAsList().toList
      }
    }
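
    The `Env` and `MySqlConfig` DTOs imported above are not shown in the post. A minimal sketch that is consistent with how `getEnv` builds them might look like this (the field names are assumptions):

      package com.dmall.scf.dto

      import org.apache.spark.sql.SparkSession

      //Assumed shape of the DTOs used by SparkUtils.getEnv (not the original source).
      case class MySqlConfig(url: String, user: String, password: String)
      case class Env(mySqlConfig: MySqlConfig, spark: SparkSession)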

    III. Defining the Spark processing flow

    Read data from MySQL or Hive -> apply the business logic -> write the result to MySQL

    1. Define the processing flow

    SparkAction.scala

    import com.dmall.scf.utils.{MySQLUtils, SparkUtils}
    import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
    
    /**
     * @description Defines the Spark processing flow
     * @author wangxuexing
     * @date 2019/12/23
     */
    trait SparkAction[T] {
      /**
       * The overall flow
       */
      def execute(args: Array[String], spark: SparkSession)={
        //1. pre-processing
        preAction
        //2. main processing
        val df = action(spark, args)
        //3. post-processing
        postAction(df)
      }
    
      /**
       * Pre-processing hook
       * @return
       */
      def preAction() = {
        //no pre-processing by default
      }

      /**
       * Main processing
       * @param spark
       * @return
       */
      def action(spark: SparkSession, args: Array[String]) : DataFrame

      /**
       * Post-processing, e.g. saving the result to MySQL
       * @param df
       */
      def postAction(df: DataFrame)={
        //write the result into the table returned by saveTable, using its SaveMode
        MySQLUtils.writeIntoMySql(df, saveTable._1, saveTable._2)
      }

      /**
       * SaveMode and target table name
       * @return
       */
      def saveTable: (SaveMode, String)
    }
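
    A driver can then wire everything together. The sketch below is hypothetical (the object name, the hard-coded action, and the `env.spark` field name are assumptions; `RegMoneyDistributionAction` is the concrete action defined in section III.3):

      //Hypothetical entry point: obtain the environment, run the shared flow, stop Spark.
      object Main {
        def main(args: Array[String]): Unit = {
          val env = SparkUtils.getEnv          //SparkSession + MySQL config chosen by Profile.currentEvn
          val spark = env.spark                //field name assumed, see the Env sketch above
          try {
            RegMoneyDistributionAction.execute(args, spark)
          } finally {
            spark.stop()
          }
        }
      }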

    2. Implement the flow

    KanbanAction.scala

    import com.dmall.scf.SparkAction
    import com.dmall.scf.dto.KanbanFieldValue
    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
    import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
    
    import scala.collection.JavaConverters._
    
    /**
     * @description
     * scf-spark
     * @author wangxuexing
     * @date 2020/1/10
     */
    trait KanbanAction extends SparkAction[KanbanFieldValue] {
      /**
       * Build a DataFrame from the result list
       * @param resultList
       * @param spark
       * @return
       */
      def getDataFrame(resultList: List[KanbanFieldValue], spark: SparkSession): DataFrame= {
        //build the schema for the result rows
        val fields = List(StructField("company_id", LongType, nullable = false),
          StructField("statistics_date", StringType, nullable = false),
          StructField("field_id", LongType, nullable = false),
          StructField("field_type", StringType, nullable = false),
          StructField("field_value", StringType, nullable = false),
          StructField("other_value", StringType, nullable = false))
        val schema = StructType(fields)
        //convert each result bean into a Row
        val rowRDD = resultList.map(x=>Row(x.companyId, x.statisticsDate, x.fieldId, x.fieldType, x.fieldValue, x.otherValue)).asJava
        //create the DataFrame from the rows and the schema
        spark.createDataFrame(rowRDD, schema)
      }
      /**
       * SaveMode and target table name
       *
       * @return
       */
      override def saveTable: (SaveMode, String) = (SaveMode.Append, "scfc_kanban_field_value")
    }
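
    The `KanbanFieldValue` DTO is not shown in the post. Based on the schema built in `getDataFrame`, it presumably looks roughly like this (the field types are inferred and therefore assumptions; in particular, the type of `statisticsDate` depends on what `DateUtils.addSomeDays` returns):

      package com.dmall.scf.dto

      //Assumed shape, inferred from the StructType in KanbanAction.getDataFrame.
      case class KanbanFieldValue(companyId: Long,
                                  statisticsDate: String,
                                  fieldId: Long,
                                  fieldType: String,
                                  fieldValue: String,
                                  otherValue: String)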

    3. Implement the concrete business logic

    import com.dmall.scf.dto.{KanbanFieldValue, RegisteredMoney}
    import com.dmall.scf.utils.{DateUtils, MySQLUtils}
    import org.apache.spark.sql.{DataFrame, SparkSession}
    
    /**
     * @description
     * scf-spark registered-capital distribution
     * @author wangxuexing
     * @date 2020/1/10
     */
    object RegMoneyDistributionAction extends KanbanAction{
      val CLASS_NAME = this.getClass.getSimpleName().filter(!_.equals('$'))
    
      val RANGE_50W = BigDecimal(50)
      val RANGE_100W = BigDecimal(100)
      val RANGE_500W = BigDecimal(500)
      val RANGE_1000W = BigDecimal(1000)
    
      /**
       * Main processing
       *
       * @param spark
       * @return
       */
      override def action(spark: SparkSession, args: Array[String]): DataFrame = {
        import spark.implicits._
        if(args.length < 2){
          throw new Exception("Please specify whether to run for the current year (1) or last year (2): 1|2")
        }
        val lastDay = DateUtils.addSomeDays(-1)
        val (starDate, endDate, filedId) = args(1) match {
          case "1" =>
            val startDate = DateUtils.isFirstDayOfYear match {
              case true => DateUtils.getFirstDateOfLastYear
              case false => DateUtils.getFirstDateOfCurrentYear
            }
    
            (startDate, DateUtils.formatNormalDateStr(lastDay), 44)
          case "2" =>
            val startDate = DateUtils.isFirstDayOfYear match {
              case true => DateUtils.getLast2YearFirstStr(DateUtils.YYYY_MM_DD)
              case false => DateUtils.getLastYearFirstStr(DateUtils.YYYY_MM_DD)
            }
            val endDate = DateUtils.isFirstDayOfYear match {
              case true => DateUtils.getLast2YearLastStr(DateUtils.YYYY_MM_DD)
              case false => DateUtils.getLastYearLastStr(DateUtils.YYYY_MM_DD)
            }
            (startDate, endDate, 45)
          case _ =>  throw new Exception("Please pass a valid argument: current year (1) or last year (2): 1|2")
        }
    
        val sql = s"""SELECT
                          id,
                          IFNULL(registered_money, 0) registered_money
                        FROM
                          scfc_supplier_info
                        WHERE
                          `status` = 3
                        AND yn = 1"""
        val allDimension = MySQLUtils.getDFFromMysql(spark, sql)
        val beanList =  allDimension.map(x => RegisteredMoney(x.getLong(0), x.getDecimal(1)))
        //val filterList =  SparkUtils.dataFrame2Bean[RegisteredMoney](allDimension, classOf[RegisteredMoney])
        val hiveSql = s"""
                       SELECT DISTINCT(a.company_id) supplier_ids
                        FROM wumart2dmall.wm_ods_cx_supplier_card_info a
                        JOIN wumart2dmall.wm_ods_jrbl_loan_dkzhxx b ON a.card_code = b.gshkahao
                        WHERE a.audit_status = '2'
                          AND b.jiluztai = '0'
                          AND to_date(b.gxinshij)>= '${starDate}'
                          AND to_date(b.gxinshij)<= '${endDate}'"""
        println(hiveSql)
        val supplierIds = spark.sql(hiveSql).collect().map(_.getLong(0))
        val filterList = beanList.filter(x => supplierIds.contains(x.supplierId))
    
        val range1 =  spark.sparkContext.collectionAccumulator[Int]
        val range2 =  spark.sparkContext.collectionAccumulator[Int]
        val range3 =  spark.sparkContext.collectionAccumulator[Int]
        val range4 =  spark.sparkContext.collectionAccumulator[Int]
        val range5 =  spark.sparkContext.collectionAccumulator[Int]
        filterList.foreach(x => {
          if(RANGE_50W.compare(x.registeredMoney) >= 0){
            range1.add(1)
          } else if (RANGE_50W.compare(x.registeredMoney) < 0 && RANGE_100W.compare(x.registeredMoney) >= 0){
            range2.add(1)
          } else if (RANGE_100W.compare(x.registeredMoney) < 0 && RANGE_500W.compare(x.registeredMoney) >= 0){
            range3.add(1)
          } else if (RANGE_500W.compare(x.registeredMoney) < 0 && RANGE_1000W.compare(x.registeredMoney) >= 0){
            range4.add(1)
          } else if (RANGE_1000W.compare(x.registeredMoney) < 0){
            range5.add(1)
          }
        })
        val resultList = List(("50万元以下", range1.value.size()), ("50-100万元", range2.value.size()),
                              ("100-500万元", range3.value.size()),("500-1000万元", range4.value.size()),
                              ("1000万元以上", range5.value.size())).map(x => {
          KanbanFieldValue(1, lastDay, filedId, x._1, x._2.toString, "")
        })
    
        getDataFrame(resultList, spark)
      }
    }

    The complete project source code is available at:

    https://github.com/barrywang88/spark-tool

    https://gitee.com/barrywang/spark-tool
