zoukankan      html  css  js  c++  java
  • 基于Spark ML的Titanic Challenge (Top 6%)

    下面代码按照之前参加Kaggle的python代码改写,只完成了模型的训练过程,还需要对test集的数据进行转换和对test集进行预测。
    scala 2.11.12
    spark 2.2.2

    package ML.Titanic
    
    import org.apache.spark.SparkContext
    import org.apache.spark.sql._
    import org.apache.spark.sql.functions._
    import org.apache.spark.ml.feature.Bucketizer
    import org.apache.spark.ml.feature.QuantileDiscretizer
    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.ml.feature.OneHotEncoder
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.ml.classification.GBTClassifier
    import org.apache.spark.ml.tuning.ParamGridBuilder
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.tuning.{TrainValidationSplit, TrainValidationSplitModel}
    import org.apache.spark.ml.PipelineModel
    import org.apache.spark.sql.types._
    
    
    /**
      * GBTClassifier for predicting survival in the Titanic ship
      */
    object TitanicChallenge {
    
      def main(args: Array[String]) {
    
        val spark = SparkSession.builder.
          master("local[*]")
          .appName("example")
          .config("spark.sql.shuffle.partitions", 20)
          .config("spark.default.parallelism", 20)
          .config("spark.driver.memory", "4G")
          .config("spark.memory.fraction", 0.75)
          .getOrCreate()
        val sc = spark.sparkContext
    
        spark.sparkContext.setLogLevel("ERROR")
    
        val schemaArray = StructType(Array(
          StructField("PassengerId", IntegerType, true),
          StructField("Survived", IntegerType, true),
          StructField("Pclass", IntegerType, true),
          StructField("Name", StringType, true),
          StructField("Sex", StringType, true),
          StructField("Age", FloatType, true),
          StructField("SibSp", IntegerType, true),
          StructField("Parch", IntegerType, true),
          StructField("Ticket", StringType, true),
          StructField("Fare", FloatType, true),
          StructField("Cabin", StringType, true),
          StructField("Embarked", StringType, true)
        ))
    
        val path = "Titanic/"
        val df = spark.read
          .option("header", "true")
          .schema(schemaArray)
          .csv(path + "train.csv")
          .drop("PassengerId")
    //    df.cache()
    
        val utils = new TitanicChallenge(spark)
        val df2 = utils.transCabin(df)
        val df3 = utils.transTicket(sc, df2)
        val df4 = utils.transEmbarked(df3)
        val df5 = utils.extractTitle(sc, df4)
        val df6 = utils.transAge(sc, df5)
        val df7 = utils.categorizeAge(df6)
        val df8 = utils.createFellow(df7)
        val df9 = utils.categorizeFellow(df8)
        val df10 = utils.extractFName(df9)
        val df11 = utils.transFare(df10)
    
        val prePipelineDF = df11.select("Survived", "Pclass", "Sex",
          "Age_categorized", "fellow_type", "Fare_categorized",
          "Embarked", "Cabin", "Ticket",
          "Title", "family_type")
    
    //    prePipelineDF.show(1)
    //    +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+
    //    |Survived|Pclass| Sex|Age_categorized|fellow_type|Fare_categorized|Embarked|Cabin|Ticket|Title|family_type|
    //    +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+
    //    |       0|     3|male|            3.0|      Small|             0.0|       S|    U|     0|   Mr|          0|
    //    +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+
    
        val (df_indexed, colsTrain) = utils.index_onehot(prePipelineDF)
        df_indexed.cache()
    
        //训练模型
        val validatorModel = utils.trainData(df_indexed, colsTrain)
    
        //打印最优模型的参数
        val bestModel = validatorModel.bestModel
        println(bestModel.asInstanceOf[PipelineModel].stages.last.extractParamMap)
    
        //打印各模型的成绩和参数
        val paramsAndMetrics = validatorModel.validationMetrics
          .zip(validatorModel.getEstimatorParamMaps)
          .sortBy(-_._1)
        paramsAndMetrics.foreach { case (metric, params) =>
          println(metric)
          println(params)
          println()
        }
    
        validatorModel.write.overwrite().save(path + "Titanic_gbtc")
    
        spark.stop()
      }
    }
    
    class TitanicChallenge(private val spark: SparkSession) extends Serializable {
    
      import spark.implicits._
    
      //Cabin,用“U”填充null,并提取Cabin的首字母
      def transCabin(df: Dataset[Row]): Dataset[Row] = {
        df.na.fill("U", Seq("Cabin"))
          .withColumn("Cabin", substring($"Cabin", 0, 1))
      }
    
      //
      def transTicket(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = {
    
        ////提取船票的号码,如“A/5 21171”中的21171
        val medDF1 = df.withColumn("Ticket", split($"Ticket", " "))
          .withColumn("Ticket", $"Ticket"(size($"Ticket").minus(1)))
          .filter($"Ticket" =!= "LINE")//去掉某种特殊的船票
    
        //对船票号进行分类,小于四位号码的为“1”,四位号码的以第一个数字开头,后面接上“0”,大于4位号码的,取前三个数字开头。如21171变为211
        val ticketTransUdf = udf((ticket: String) => {
          if (ticket.length < 4) {
            "1"
          } else if (ticket.length == 4){
            ticket(0)+"0"
          } else {
            ticket.slice(0, 3)
          }
        })
        val medDF2 = medDF1.withColumn("Ticket", ticketTransUdf($"Ticket"))
    
        //将数量小于等于5的类别统一归为“0”。先统计小于5的名单,然后用udf进行转换。
        val filterList = medDF2.groupBy($"Ticket").count()
          .filter($"count" <= 5)
          .map(row => row.getString(0))
          .collect.toList
    
        val filterList_bc = sc.broadcast(filterList)
    
        val ticketTransAdjustUdf = udf((subticket: String) => {
          if (filterList_bc.value.contains(subticket)) "0"
          else subticket
        })
    
        medDF2.withColumn("Ticket", ticketTransAdjustUdf($"Ticket"))
      }
    
      //用“S”填充null
      def transEmbarked(df: Dataset[Row]): Dataset[Row] = {
        df.na.fill("S", Seq("Embarked"))
      }
    
      def extractTitle(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = {
        val regex = ".*, (.*?)\..*"
    
        //对头衔进行归类
        val titlesMap = Map(
          "Capt"-> "Officer",
          "Col"-> "Officer",
          "Major"-> "Officer",
          "Jonkheer"-> "Royalty",
          "Don"-> "Royalty",
          "Sir" -> "Royalty",
          "Dr"-> "Officer",
          "Rev"-> "Officer",
          "the Countess"->"Royalty",
          "Mme"-> "Mrs",
          "Mlle"-> "Miss",
          "Ms"-> "Mrs",
          "Mr" -> "Mr",
          "Mrs" -> "Mrs",
          "Miss" -> "Miss",
          "Master" -> "Master",
          "Lady" -> "Royalty"
        )
    
        val titlesMap_bc = sc.broadcast(titlesMap)
    
        df.withColumn("Title", regexp_extract(($"Name"), regex, 1))
          .na.replace("Title", titlesMap_bc.value)
      }
    
      //根据null age的records对应的Pclass和Name_final分组后的平均来填充缺失age。
      // 首先,生成分组key,并获取分组后的平均年龄map。然后广播map,当Age为null时,用udf返回需要填充的值。
      def transAge(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = {
        val medDF = df.withColumn("Pclass_Title_key", concat($"Title", $"Pclass"))
        val meanAgeMap = medDF.groupBy("Pclass_Title_key")
          .mean("Age")
          .map(row => (row.getString(0), row.getDouble(1)))
          .collect().toMap
    
        val meanAgeMap_bc = sc.broadcast(meanAgeMap)
    
        val fillAgeUdf = udf((comb_key: String) => meanAgeMap_bc.value.getOrElse(comb_key, 0.0))
    
        medDF.withColumn("Age", when($"Age".isNull, fillAgeUdf($"Pclass_Title_key")).otherwise($"Age"))
      }
    
      //对Age进行分类
      def categorizeAge(df: Dataset[Row]): Dataset[Row] = {
        val ageBucketBorders = 0.0 +: (10.0 to 60.0 by 5.0).toArray :+ 150.0
        val ageBucketer = new Bucketizer().setSplits(ageBucketBorders).setInputCol("Age").setOutputCol("Age_categorized")
        ageBucketer.transform(df).drop("Pclass_Title_key")
      }
    
      //将SibSp和Parch相加,得出同行人数
      def createFellow(df: Dataset[Row]): Dataset[Row] = {
        df.withColumn("fellow", $"SibSp" + $"Parch")
      }
    
      //fellow_type, 对fellow进行分类。此处其实可以留到pipeline部分一次性完成。
      def categorizeFellow(df: Dataset[Row]): Dataset[Row] = {
        df.withColumn("fellow_type", when($"fellow" === 0, "Alone")
          .when($"fellow" <= 3, "Small")
          .otherwise("Large"))
      }
    
      def extractFName(df: Dataset[Row]): Dataset[Row] = {
    
        //检查df是否有Survived和fellow列
        if (!df.columns.contains("Survived") || !df.columns.contains("fellow")){
          throw new IllegalArgumentException(
            """
              |Check if the argument is a training set or if this training set contains column named "fellow"
            """.stripMargin)
        }
    
        //FName,提取家庭名称。例如:"Johnston, Miss. Catherine Helen ""Carrie""" 提取出Johnston
        // 由于spark的读取csv时,如果有引号,读取就会出现多余的引号,所以除了split逗号,还要再split一次引号。
        val medDF = df
          .withColumn("FArray", split($"Name", ","))
          .withColumn("FName", expr("FArray[0]"))
          .withColumn("FArray", split($"FName", """))
          .withColumn("FName", $"FArray"(size($"FArray").minus(1)))
    
        //family_type,分为三类,第一类是60岁以下女性遇难的家庭,第二类是18岁以上男性存活的家庭,第三类其他。
        val femaleDiedFamily_filter = $"Sex" === "female" and $"Age" < 60 and $"Survived" === 0 and $"fellow" > 0
    
        val maleSurvivedFamily_filter = $"Sex" === "male" and $"Age" >= 18 and $"Survived" === 1 and $"fellow" > 1
    
        val resDF = medDF.withColumn("family_type", when(femaleDiedFamily_filter, 1)
          .when(maleSurvivedFamily_filter, 2).otherwise(0))
    
        //familyTable,家庭分类名单,用于后续test集的转化。此处用${FName}_${family_type}的形式保存。
        resDF.filter($"family_type".isin(1,2))
          .select(concat($"FName", lit("_"), $"family_type"))
          .dropDuplicates()
          .write.format("text").mode("overwrite").save("familyTable")
    
        //如果需要直接收集成Map的话,可用下面代码。
        // 此代码先利用mapPartitions对各分块的数据进行聚合,降低直接调用count而使driver挂掉的风险。
        //另外新建一个默认Set是为了防止某个partition并没有数据的情况(出现概率可能比较少),
        // 从而使得Set的类型变为Set[_>:Tuple]而不能直接flatten
    
        // val familyMap = df10
        //   .filter($"family_type" === 1 || $"family_type" === 2)
        //   .select("FName", "family_type")
        //   .rdd
        //   .mapPartitions{iter => {
        //       if (!iter.isEmpty) {
        //       Iterator(iter.map(row => (row.getString(0), row.getInt(1))).toSet)}
        //       else Iterator(Set(("defualt", 9)))}
        //                 }
        //   .collect()
        //   .flatten
        //   .toMap
    
        resDF
      }
    
      //Fare。首先去掉缺失的(test集合中有一个,如果量多的话,也可以像Age那样通过头衔,年龄等因数来推断)
      //然后对Fare进行分类
      def transFare(df: Dataset[Row]): Dataset[Row] = {
    
        val medDF = df.na.drop("any", Seq("Fare"))
        val fareBucketer = new QuantileDiscretizer()
          .setInputCol("Fare")
          .setOutputCol("Fare_categorized")
          .setNumBuckets(4)
    
        fareBucketer.fit(medDF).transform(medDF)
      }
    
      def index_onehot(df: Dataset[Row]): Tuple2[Dataset[Row], Array[String]] = {
        val stringCols = Array("Sex","fellow_type", "Embarked", "Cabin", "Ticket", "Title")
        val subOneHotCols = stringCols.map(cname => s"${cname}_index")
        val index_transformers: Array[org.apache.spark.ml.PipelineStage] = stringCols.map(
          cname => new StringIndexer()
            .setInputCol(cname)
            .setOutputCol(s"${cname}_index")
            .setHandleInvalid("skip")
        )
    
    
        val oneHotCols = subOneHotCols ++ Array("Pclass", "Age_categorized", "Fare_categorized", "family_type")
        val vectorCols = oneHotCols.map(cname => s"${cname}_encoded")
        val encode_transformers: Array[org.apache.spark.ml.PipelineStage] = oneHotCols.map(
          cname => new OneHotEncoder()
            .setInputCol(cname)
            .setOutputCol(s"${cname}_encoded")
        )
    
        val pipelineStage = index_transformers ++ encode_transformers
        val index_onehot_pipeline = new Pipeline().setStages(pipelineStage)
        val index_onehot_pipelineModel = index_onehot_pipeline.fit(df)
    
        val resDF = index_onehot_pipelineModel.transform(df).drop(stringCols:_*).drop(subOneHotCols:_*)
        println(resDF.columns.size)
        (resDF, vectorCols)
      }
    
      def trainData(df: Dataset[Row], vectorCols: Array[String]): TrainValidationSplitModel = {
        //separate and model pipeline,包含划分label和features,机器学习模型的pipeline
        val vectorAssembler = new VectorAssembler()
          .setInputCols(vectorCols)
          .setOutputCol("features")
    
        val gbtc = new GBTClassifier()
          .setLabelCol("Survived")
          .setFeaturesCol("features")
          .setPredictionCol("prediction")
    
        val pipeline = new Pipeline().setStages(Array(vectorAssembler, gbtc))
    
        val paramGrid = new ParamGridBuilder()
          .addGrid(gbtc.stepSize, Seq(0.1))
          .addGrid(gbtc.maxDepth, Seq(5))
          .addGrid(gbtc.maxIter, Seq(20))
          .build()
    
        val multiclassEval = new MulticlassClassificationEvaluator()
          .setLabelCol("Survived")
          .setPredictionCol("prediction")
          .setMetricName("accuracy")
    
        val tvs = new TrainValidationSplit()
          .setTrainRatio(0.75)
          .setEstimatorParamMaps(paramGrid)
          .setEstimator(pipeline)
          .setEvaluator(multiclassEval)
    
        tvs.fit(df)
      }
    }
    
  • 相关阅读:
    python获取股票数据接口
    Excel使用VBA读取实时WebService股票数据
    安装Pycharm
    Pycharm2019使用
    KLine
    pycharm下用mysql
    新浪股票接口
    SpringBoot整合持久层技术--(二)MyBatis
    SpringBoot整合持久层技术--(一)JdbcTemplate
    SpringBoot整合WEB开发--(十)配置AOP
  • 原文地址:https://www.cnblogs.com/code2one/p/9872546.html
Copyright © 2011-2022 走看看