zoukankan      html  css  js  c++  java
  • Spark学习笔记——手写数字识别

    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.regression.RandomForestRegressor
    import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, NaiveBayes, SVMWithSGD}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.optimization.L1Updater
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.{DecisionTree, RandomForest}
    import org.apache.spark.mllib.tree.configuration.Algo
    import org.apache.spark.mllib.tree.impurity.Entropy
    
    /**
      * Created by common on 17-5-17.
      */
    
    /**
      * A labelled picture record: an integer label paired with a list of doubles.
      *
      * @param label integer label attached to the picture
      * @param pic   numeric values of the picture (presumably pixel intensities —
      *              confirm against callers); defaults to the empty list
      */
    case class LabeledPic(label: Int, pic: List[Double] = Nil)
    
    /**
      * Driver for digit-recognition experiments with Spark MLlib.
      *
      * The active code only prepares the training set as an RDD of
      * [[LabeledPoint]]. Four alternative experiments (naive Bayes, logistic
      * regression, decision tree, random forest) are kept below as commented-out
      * code; uncomment exactly one of them to train, evaluate and predict.
      */
    object DigitRecognizer {

      /**
        * Creates a local Spark context, parses the header-less training CSV into
        * labelled feature vectors and stops the context before returning, even if
        * an exception is thrown.
        *
        * @param args command-line arguments (not used)
        */
      def main(args: Array[String]): Unit = {

        val conf = new SparkConf().setAppName("DigitRecgonizer").setMaster("local")
        val sc = new SparkContext(conf)

        // Always release the context: the original never called stop(), leaking
        // the context's resources whenever main returned or failed.
        try {
          // The header row must be removed first: sed 1d train.csv > train_noheader.csv
          val trainFile = "file:///media/common/工作/kaggle/DigitRecognizer/train_noheader.csv"
          val trainRawData = sc.textFile(trainFile)
          // Split every line on commas, producing an RDD of string arrays.
          val trainRecords = trainRawData.map(_.split(","))

          // Column 0 is the label; every remaining column is a feature value.
          // RDD transformations are lazy, so nothing is read or parsed until one
          // of the experiments below is enabled and triggers an action.
          val trainData = trainRecords.map { r =>
            val label = r(0).toInt
            val features = r.drop(1).map(_.toDouble)
            LabeledPoint(label, Vectors.dense(features))
          }


          //    // Use the naive Bayes model
          //    val nbModel = NaiveBayes.train(trainData)
          //
          //    val nbTotalCorrect = trainData.map { point =>
          //      if (nbModel.predict(point.features) == point.label) 1 else 0
          //    }.sum
          //    val nbAccuracy = nbTotalCorrect / trainData.count
          //
          //    println("贝叶斯模型正确率:" + nbAccuracy)
          //
          //    // Predict on the test data
          //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
          //    // Split every line on commas, producing an RDD of string arrays
          //    val testRecords = testRawData.map(line => line.split(","))
          //
          //    val testData = testRecords.map { r =>
          //      val features = r.map(d => d.toDouble)
          //      Vectors.dense(features)
          //    }
          //    val predictions = nbModel.predict(testData).map(p => p.toInt)
          //    // Save the predictions
          //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict")


          //    // Use the logistic regression model
          //    // NOTE(review): the original comment and the printed message call this
          //    // "linear regression", but the code trains LogisticRegressionWithLBFGS.
          //    val lrModel = new LogisticRegressionWithLBFGS()
          //      .setNumClasses(10)
          //      .run(trainData)
          //
          //    val lrTotalCorrect = trainData.map { point =>
          //      if (lrModel.predict(point.features) == point.label) 1 else 0
          //    }.sum
          //    val lrAccuracy = lrTotalCorrect / trainData.count
          //
          //    println("线性回归模型正确率:" + lrAccuracy)
          //
          //    // Predict on the test data
          //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
          //    // Split every line on commas, producing an RDD of string arrays
          //    val testRecords = testRawData.map(line => line.split(","))
          //
          //    val testData = testRecords.map { r =>
          //      val features = r.map(d => d.toDouble)
          //      Vectors.dense(features)
          //    }
          //    val predictions = lrModel.predict(testData).map(p => p.toInt)
          //    // Save the predictions
          //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict1")


          //    // Use the decision tree model
          //    val maxTreeDepth = 10
          //    val numClass = 10
          //    val dtModel = DecisionTree.train(trainData, Algo.Classification, Entropy, maxTreeDepth, numClass)
          //
          //    val dtTotalCorrect = trainData.map { point =>
          //      if (dtModel.predict(point.features) == point.label) 1 else 0
          //    }.sum
          //    val dtAccuracy = dtTotalCorrect / trainData.count
          //
          //    println("决策树模型正确率:" + dtAccuracy)
          //
          //    // Predict on the test data
          //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
          //    // Split every line on commas, producing an RDD of string arrays
          //    val testRecords = testRawData.map(line => line.split(","))
          //
          //    val testData = testRecords.map { r =>
          //      val features = r.map(d => d.toDouble)
          //      Vectors.dense(features)
          //    }
          //    val predictions = dtModel.predict(testData).map(p => p.toInt)
          //    // Save the predictions
          //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict2")


          //    // Use the random forest model
          //    // NOTE(review): the other experiments use 10 classes; numClasses = 30
          //    // looks unintended for digit labels — confirm before enabling.
          //    val numClasses = 30
          //    val categoricalFeaturesInfo = Map[Int, Int]()
          //    val numTrees = 50
          //    val featureSubsetStrategy = "auto"
          //    val impurity = "gini"
          //    val maxDepth = 10
          //    val maxBins = 32
          //    val rtModel = RandomForest.trainClassifier(trainData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
          //
          //    val rtTotalCorrect = trainData.map { point =>
          //      if (rtModel.predict(point.features) == point.label) 1 else 0
          //    }.sum
          //    val rtAccuracy = rtTotalCorrect / trainData.count
          //
          //    println("随机森林模型正确率:" + rtAccuracy)
          //
          //    // Predict on the test data
          //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
          //    // Split every line on commas, producing an RDD of string arrays
          //    val testRecords = testRawData.map(line => line.split(","))
          //
          //    val testData = testRecords.map { r =>
          //      val features = r.map(d => d.toDouble)
          //      Vectors.dense(features)
          //    }
          //    val predictions = rtModel.predict(testData).map(p => p.toInt)
          //    // Save the predictions
          //    // NOTE(review): this writes to the same "test_predict" directory as the
          //    // naive Bayes experiment — presumably a distinct path is wanted; confirm.
          //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict")

        } finally {
          sc.stop()
        }

      }

    }
    
  • 相关阅读:
    cf1100 F. Ivan and Burgers
    cf 1033 D. Divisors
    LeetCode 17. 电话号码的字母组合
    LeetCode 491. 递增的子序列
    LeetCode 459.重复的子字符串
    LeetCode 504. 七进制数
    LeetCode 3.无重复字符的最长子串
    LeetCode 16.06. 最小差
    LeetCode 77. 组合
    LeetCode 611. 有效三角形个数
  • 原文地址:https://www.cnblogs.com/tonglin0325/p/6906524.html
Copyright © 2011-2022 走看看