zoukankan      html  css  js  c++  java
  • 大三寒假第十六天(逻辑斯蒂回归分类器)

    一定要对文本数据集进行预处理

    1.导入包

    import org.apache.spark.ml.feature.PCA
    import org.apache.spark.sql.Row
    import org.apache.spark.ml.linalg.{Vector,Vectors}
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.{Pipeline,PipelineModel}
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer,HashingTF, Tokenizer}
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.classification.LogisticRegressionModel
    import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression}
    import org.apache.spark.sql.functions;
    import spark.implicits._
    

    2.定义class

    /** One parsed input record: a numeric feature vector together with its raw string label. */
    case class Adult(
      features: Vector,
      label: String
    )

    3.读取数据,并转换成DF

    /**
     * Reads a comma-separated text file and converts it into a DataFrame of [[Adult]] rows.
     *
     * Fields 0, 2, 4, 10, 11 and 12 are parsed as doubles to build the feature vector and
     * field 14 is kept as the string label. Lines with fewer than 15 fields (for example a
     * trailing blank line) are skipped instead of failing the whole job.
     *
     * @param path location of the input file, e.g. a `file:///` URI
     * @return a DataFrame with the columns `features` and `label`
     */
    def loadAdult(path: String): org.apache.spark.sql.DataFrame = {
      val featureColumns = Seq(0, 2, 4, 10, 11, 12)
      val labelColumn = 14
      sc.textFile(path)
        .map(_.split(","))
        // Guard against short or blank lines so indexing below cannot go out of bounds.
        .filter(_.length > labelColumn)
        .map { p =>
          Adult(Vectors.dense(featureColumns.map(i => p(i).toDouble).toArray), p(labelColumn))
        }
        .toDF()
    }

    /* Raw data set used to train the model */
    val df = loadAdult("file:///usr/spark/sparkdata/adult.txt")

    /* Test data set */
    val test = loadAdult("file:///usr/spark/sparkdata/test.txt")
    

      

    4.如果维度过多,用PCA主成分分析进行降维(6维变3维)

    // Fit a PCA model on df; setK receives the number of output dimensions (3 here).
    // NOTE(review): the later VectorIndexer stage reads "features", not "PCAfeatures" —
    // confirm whether the reduced column is meant to feed the classifier.
    val pca = new PCA()
      .setInputCol("features")
      .setOutputCol("PCAfeatures")
      .setK(3)
      .fit(df)

    // Run the fitted PCA model over df to obtain a new DataFrame.
    val trainningdata = pca.transform(df)
    

    5.分别获取标签列,和特征列,进行索引和重命名

    // Label column: fit a StringIndexer that writes the indexed label to "indexedLabel".
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(trainningdata)

    // Feature column: fit a VectorIndexer that writes its output to "indexedFeatures".
    // NOTE(review): the input is the original "features" column, so the "PCAfeatures"
    // column produced earlier is not used here — confirm this is intended.
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .fit(trainningdata)

    6.设置逻辑斯蒂参数

    // Logistic regression estimator reading the indexed label/feature columns, capped at 10 iterations.
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setFeaturesCol("indexedFeatures")
      .setLabelCol("indexedLabel")

    7.设置一个labelConverter,把预测类型重新转换成字符型

    // Maps the numeric "prediction" column back to the original string labels known to labelIndexer.
    val labelConverter = new IndexToString()
      .setLabels(labelIndexer.labels)
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")

    8.构建pipeline,设置stage,并用fit()进行训练

    // Estimator: chain label indexing, feature indexing, logistic regression and
    // label conversion into a single pipeline.
    // Fix: the last stage is `labelConverter` (the original referenced an undefined `labelConvert`).
    val lrPipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))

    // Model: train the pipeline on the training DataFrame.
    // Fix: the DataFrame is named `trainningdata` (the original referenced an undefined `tranningdata`).
    val lrPipelineModel = lrPipeline.fit(trainningdata)
    

    9.用构建好的模型进行预测

    // Predict on the test data set with the trained pipeline model. The test set is first
    // passed through the fitted PCA model so it goes through the same dimensionality
    // reduction as the training data (it gains the "PCAfeatures" column).
    val lrPredictions = lrPipelineModel.transform(pca.transform(test))
    

    10.输出预测结果

    // Collect the selected prediction columns and print one line per row.
    lrPredictions
      .select("predictedLabel", "label", "features", "probability")
      .collect()
      .foreach {
        case Row(predictedLabel: String, label: String, features: Vector, prob: Vector) =>
          println(s"($label,$features)-->prob=$prob,predictedLabel=$predictedLabel")
      }
    

    11.模型评估

      

    // Evaluator comparing the indexed true label with the predicted index.
    // The metric is set to "accuracy" explicitly: the evaluator's default metric is f1,
    // which would not match what the result below is named after.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    // Fix: evaluate the predictions DataFrame `lrPredictions` (the original referenced an
    // undefined `lrpridicton`). The val keeps its original spelling in case later code uses it.
    val lrAccuacry = evaluator.evaluate(lrPredictions)

      

      

      

      

     

  • 相关阅读:
    主流的Nosql数据库的对比
    CCF考试真题题解
    排序
    2017-10-03-afternoon
    POJ——T 2728 Desert King
    51Nod——T 1686 第K大区间
    POJ——T 2976 Dropping tests
    2017-10-02-afternoon
    入参是小数的String,返回小数乘以100的String
    银联支付踩过的坑
  • 原文地址:https://www.cnblogs.com/sakura-xxg/p/15766382.html
Copyright © 2011-2022 走看看