zoukankan      html  css  js  c++  java
  • Spark学习笔记——手写数字识别

    import org.apache.spark.ml.classification.RandomForestClassifier
    import org.apache.spark.ml.regression.RandomForestRegressor
    import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, NaiveBayes, SVMWithSGD}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.optimization.L1Updater
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.{DecisionTree, RandomForest}
    import org.apache.spark.mllib.tree.configuration.Algo
    import org.apache.spark.mllib.tree.impurity.Entropy
    
    /**
      * Created by common on 17-5-17.
      */
    
    /**
      * Pairs an integer `label` with the numeric values of one picture (`pic`);
      * `pic` defaults to an empty list.
      *
      * NOTE(review): presumably `pic` holds the flattened pixel intensities of a digit image.
      * This class is not used by the code visible in this file, so confirm before relying on it.
      */
    case class LabeledPic(label: Int, pic: List[Double] = Nil)
    
    /**
      * Spark MLlib experiments for the Kaggle "Digit Recognizer" competition.
      *
      * Training rows have the form `label,pixel0,pixel1,...`; test rows carry the pixel columns
      * only. Both CSVs are expected to have their header row removed beforehand, e.g.
      * `sed 1d train.csv > train_noheader.csv`.
      *
      * `main` only prepares the training RDD; no Spark action is invoked unless one of the
      * commented-out model experiments inside `main` is enabled.
      */
    object DigitRecognizer {

      import org.apache.spark.mllib.linalg.Vector
      import org.apache.spark.rdd.RDD

      /** Directory holding the competition CSVs and the prediction outputs. */
      private val DataDir = "file:///media/common/工作/kaggle/DigitRecognizer"

      /** Header-less training CSV used when no path is given on the command line. */
      private val DefaultTrainFile = s"$DataDir/train_noheader.csv"

      /**
        * Fraction of `data` for which `predict` returns the point's label.
        *
        * When `data` is the training set this is a *training* accuracy, which overstates how
        * well the model does on unseen digits. Returns NaN for an empty RDD (0.0 / 0).
        *
        * @param data    labelled points to score
        * @param predict maps a feature vector to a predicted label
        */
      private def accuracy(data: RDD[LabeledPoint])(predict: Vector => Double): Double = {
        // The closure only captures `predict`, so nothing from this object has to be serialized.
        val correct = data.filter(point => predict(point.features) == point.label).count()
        correct.toDouble / data.count()
      }

      /**
        * Predicts a digit for every non-blank row of `testFile` (comma-separated numeric
        * features, no label column) and writes the predictions, one integer per line, as a
        * single-partition text output in `outputDir`.
        *
        * Note: `saveAsTextFile` fails if `outputDir` already exists.
        */
      private def predictAndSave(sc: SparkContext, testFile: String, outputDir: String)(
          predict: Vector => Double): Unit =
        sc.textFile(testFile)
          .filter(line => line.trim.nonEmpty)
          .map(line => predict(Vectors.dense(line.split(",").map(_.toDouble))).toInt)
          .coalesce(1) // a single output file, like a submission
          .saveAsTextFile(outputDir)

      /**
        * Builds the training RDD and, if one is enabled, runs a model experiment.
        *
        * @param args optional; `args(0)` overrides the training CSV location, otherwise
        *             `train_noheader.csv` under `DataDir` is used.
        */
      def main(args: Array[String]): Unit = {
        // "local" runs Spark in-process with a single worker thread.
        val conf = new SparkConf().setAppName("DigitRecgonizer").setMaster("local")
        val sc = new SparkContext(conf)
        try {
          val trainFile = args.headOption.getOrElse(DefaultTrainFile)

          // Each row is `label,pixel0,pixel1,...`; blank lines are skipped because
          // `"".toInt` would throw.
          val trainData = sc.textFile(trainFile)
            .filter(line => line.trim.nonEmpty)
            .map { line =>
              val fields = line.split(",")
              LabeledPoint(fields(0).toInt, Vectors.dense(fields.drop(1).map(_.toDouble)))
            }

          // ---------------------------------------------------------------------------------
          // Experiments. Each one trains on the full training set, prints its *training*
          // accuracy and writes test-set predictions to its own output directory.
          // ---------------------------------------------------------------------------------

          // --- Naive Bayes ---
          // val nbModel = NaiveBayes.train(trainData)
          // println("Naive Bayes training accuracy: " + accuracy(trainData)(v => nbModel.predict(v)))
          // predictAndSave(sc, s"$DataDir/test_noheader.csv", s"$DataDir/test_predict")(
          //   v => nbModel.predict(v))

          // --- Multinomial logistic regression (L-BFGS), 10 classes for digits 0-9 ---
          // val lrModel = new LogisticRegressionWithLBFGS()
          //   .setNumClasses(10)
          //   .run(trainData)
          // println("Logistic regression training accuracy: " +
          //   accuracy(trainData)(v => lrModel.predict(v)))
          // predictAndSave(sc, s"$DataDir/test_noheader.csv", s"$DataDir/test_predict1")(
          //   v => lrModel.predict(v))

          // --- Decision tree (entropy impurity, max depth 10, 10 classes) ---
          // val maxTreeDepth = 10
          // val numClass = 10
          // val dtModel = DecisionTree.train(trainData, Algo.Classification, Entropy, maxTreeDepth, numClass)
          // println("Decision tree training accuracy: " + accuracy(trainData)(v => dtModel.predict(v)))
          // predictAndSave(sc, s"$DataDir/test_noheader.csv", s"$DataDir/test_predict2")(
          //   v => dtModel.predict(v))

          // --- Random forest (50 trees, gini impurity, max depth 10, 32 bins) ---
          // Digits have exactly 10 classes (0-9), so numClasses is 10 (it used to be 30).
          // val numClasses = 10
          // val categoricalFeaturesInfo = Map[Int, Int]() // all features are continuous
          // val numTrees = 50
          // val featureSubsetStrategy = "auto"
          // val impurity = "gini"
          // val maxDepth = 10
          // val maxBins = 32
          // val rfModel = RandomForest.trainClassifier(trainData, numClasses, categoricalFeaturesInfo,
          //   numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
          // println("Random forest training accuracy: " + accuracy(trainData)(v => rfModel.predict(v)))
          // predictAndSave(sc, s"$DataDir/test_noheader.csv", s"$DataDir/test_predict3")(
          //   v => rfModel.predict(v))
        } finally {
          // Release the Spark context even if an experiment throws.
          sc.stop()
        }
      }

    }
    
  • 相关阅读:
    discuz X3.2 自定义系统广告详解
    windows平台myeclipse+PDT+apache+xdebug调试php
    南浮的IT民工
    linux实践——编译安装两个apache
    如何使maven+jetty运行时不锁定js和css[转]
    linux实践——ubuntu搭建 svn 服务
    测试代码插件(插入代码块)
    FTP 文件接口按天批处理脚本实例
    7月份工作小结
    报表开发过程
  • 原文地址:https://www.cnblogs.com/tonglin0325/p/6906524.html
Copyright © 2011-2022 走看看