zoukankan      html  css  js  c++  java
  • Spark学习笔记——泰坦尼克生还预测

    package kaggle
    
    import org.apache.spark.SparkContext
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.{SQLContext, SparkSession}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionWithSGD, NaiveBayes, SVMWithSGD}
    import org.apache.log4j.{Level, Logger}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.stat.Statistics
    
    
    /**
      * Created by mi on 17-5-23.
      */
    
    
    /**
      * Kaggle Titanic survival prediction using Spark MLlib logistic regression.
      *
      * The job reads the training CSV, builds three features per passenger
      * (sex, min-max scaled age, min-max scaled fare), trains a binary
      * logistic-regression model on 80% of the rows, prints the accuracy on the
      * remaining 20%, and finally writes one predicted label per row of the
      * Kaggle test CSV to a text file.
      */
    object Titanic {

      import org.apache.spark.mllib.linalg.{Vector => MLVector}
      import org.apache.spark.rdd.RDD
      import org.apache.spark.sql.{DataFrame, Row}

      /** Positions of the feature columns inside one CSV row. */
      private final case class FeatureColumns(sex: Int, age: Int, fare: Int)

      /**
        * Mean, minimum and maximum of one numeric column.
        *
        * Used both to impute missing values (with the mean) and to min-max scale.
        */
      private final case class ColumnStats(mean: Double, min: Double, max: Double) {

        /**
          * Min-max scales `value`, substituting the column mean when it is missing.
          *
          * @param value the parsed cell, or `None` when the cell was empty
          * @return `(value - min) / (max - min)`, or `0.0` when the column is
          *         constant (avoids a division by zero producing NaN/Infinity)
          */
        def scale(value: Option[Double]): Double = {
          val range = max - min
          if (range == 0.0) 0.0 else (value.getOrElse(mean) - min) / range
        }
      }

      // Column index of the label (second column) in the training file.
      private val LabelColumn = 1
      // Feature positions in the training file.
      private val TrainColumns = FeatureColumns(sex = 4, age = 5, fare = 9)
      // Feature positions in the test file: each one position lower than in training.
      private val TestColumns = FeatureColumns(sex = 3, age = 4, fare = 8)

      /** Loads a CSV file with a header row through the `com.databricks.spark.csv` source. */
      private def readCsv(sqlContext: SQLContext, path: String): DataFrame =
        sqlContext.read
          .format("com.databricks.spark.csv")
          .option("header", "true")
          .load(path)

      /**
        * Reads a numeric cell, treating `null` and blank strings as missing.
        *
        * @throws NumberFormatException if the cell is non-blank but not a number
        */
      private def numericCell(row: Row, index: Int): Option[Double] =
        Option(row(index)).map(_.toString.trim).filter(_.nonEmpty).map(_.toDouble)

      /** Computes mean/min/max over the non-missing values of one column in a single pass. */
      private def columnStats(df: DataFrame, index: Int): ColumnStats = {
        val values: RDD[MLVector] =
          df.rdd.flatMap(row => numericCell(row, index).toSeq).map(v => Vectors.dense(v))
        // One colStats call yields all three figures; calling it per figure would
        // launch a separate Spark job each time.
        val summary = Statistics.colStats(values)
        ColumnStats(summary.mean(0), summary.min(0), summary.max(0))
      }

      /**
        * Encodes the sex column as 0.0 (male) / 1.0 (female).
        *
        * @throws IllegalArgumentException for any other value, including `null`
        */
      private def encodeSex(value: Any): Double = value match {
        case "male" => 0.0
        case "female" => 1.0
        case other =>
          throw new IllegalArgumentException(s"Unexpected value in Sex column: $other")
      }

      /** Builds the (sex, scaled age, scaled fare) feature vector for one row. */
      private def featureVector(row: Row,
                                columns: FeatureColumns,
                                ageStats: ColumnStats,
                                fareStats: ColumnStats): MLVector =
        Vectors.dense(
          encodeSex(row(columns.sex)),
          ageStats.scale(numericCell(row, columns.age)),
          fareStats.scale(numericCell(row, columns.fare))
        )

      /**
        * Entry point: trains, evaluates and writes predictions.
        *
        * @param args command-line arguments (unused)
        */
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("WordCount").setMaster("local")
        val sc = new SparkContext(conf)

        // Always release the SparkContext, even if a stage fails.
        try {
          val sqlContext = new SQLContext(sc)

          // Silence noisy framework logging.
          Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
          Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)

          // Load the training data.
          val df = readCsv(sqlContext, "/home/mi/下载/kaggle/Titanic/train.csv")

          // Age and fare statistics of the training data.
          val ageStats = columnStats(df, TrainColumns.age)
          val fareStats = columnStats(df, TrainColumns.fare)

          // Pre-processing: label + feature vector per training row.
          val trainData = df.rdd.map { row =>
            val label = row(LabelColumn).toString.toInt
            LabeledPoint(label, featureVector(row, TrainColumns, ageStats, fareStats))
          }

          // Split into a training part and a held-out evaluation part.
          val Array(trainingData, testData) = trainData.randomSplit(Array(0.8, 0.2))
          // L-BFGS iterates over its input several times.
          trainingData.cache()

          // Train the model.
          val lrModel = new LogisticRegressionWithLBFGS().setNumClasses(2).run(trainingData)

          // Accuracy on the held-out part (NaN if the split happened to be empty).
          val totalCount = testData.count()
          val correctCount =
            testData.filter(point => lrModel.predict(point.features) == point.label).count()
          val accuracy = correctCount.toDouble / totalCount

          println("逻辑回归模型正确率:" + accuracy)

          // Prediction on the Kaggle test file.
          val testdf = readCsv(sqlContext, "/home/mi/下载/kaggle/Titanic/test.csv")

          // The test rows are imputed and scaled with the TRAINING statistics so
          // that they are on the same scale the model was fitted with; using
          // statistics of the test file itself would shift the inputs.
          val data = testdf.rdd.map(row => featureVector(row, TestColumns, ageStats, fareStats))

          val predictions = lrModel.predict(data).map(_.toInt)
          // Save the predictions as a single text file part.
          predictions.coalesce(1).saveAsTextFile("file:///home/mi/下载/kaggle/Titanic/test_predict")
        } finally {
          sc.stop()
        }
      }
    }
    
  • 相关阅读:
    ubuntu shell插件
    通过更改服务器解决双系统ubuntu时间+8
    ubuntu安装mysql遇到的问题
    05 面向对象:构造方法&static&继承&方法 &final
    electron 大体结构
    js时间Date对象介绍及解决getTime转换为8点的问题
    Fiddler命令行和HTTP断点调试
    使用HTTP头去绕过WAF(bypasswaf)
    Linux下php5.3.3安装mcrypt扩展
    Error: Cannot find a valid baseurl for repo: epel
  • 原文地址:https://www.cnblogs.com/tonglin0325/p/6909157.html
Copyright © 2011-2022 走看看