  • Spark Logistic Regression (LogisticRegression)

    1. Concepts

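    Logistic regression is a widely used method for predicting a categorical response. It is a special case of generalized linear models that predicts the probability of each outcome. In spark.ml, binomial logistic regression predicts a binary outcome, and multinomial logistic regression predicts a multiclass outcome.

    Use the family parameter to choose between the two algorithms, or leave it unset (the default is auto) and Spark will infer the correct variant. Multinomial logistic regression can also be used for binary classification by setting family to "multinomial"; the model then has two sets of coefficients and two intercepts.

    In a classification problem, we try to predict whether a result belongs to a given class (for example, correct or incorrect). Examples include deciding whether an email is spam and whether a financial transaction is fraudulent.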
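
    As a minimal sketch of what "predicting the probability of an outcome" means in the binomial case, the snippet below applies the logistic (sigmoid) function to a linear margin and compares the result with a threshold. The object name, the coefficients, and the input vector are hypothetical and chosen only for illustration; this is not Spark's implementation.

```scala
object LogisticSketch {
  /** Logistic (sigmoid) function: maps a real-valued margin to the interval (0, 1). */
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  /** Probability of class 1 for a linear model with the given weights and intercept. */
  def probability(weights: Array[Double], intercept: Double, x: Array[Double]): Double = {
    require(weights.length == x.length, "weights and features must have the same length")
    val margin = weights.zip(x).map { case (w, v) => w * v }.sum + intercept
    sigmoid(margin)
  }

  /** Predict 1.0 when the class-1 probability exceeds the threshold, otherwise 0.0. */
  def predict(p: Double, threshold: Double = 0.5): Double =
    if (p > threshold) 1.0 else 0.0

  def main(args: Array[String]): Unit = {
    // Hypothetical coefficients and input, for illustration only.
    val p = probability(Array(0.5, -0.25), 0.1, Array(2.0, 4.0))
    println(s"P(class 1) = $p, prediction = ${predict(p)}")
  }
}
```

    The same relationship is visible in the result table of section 3: the first row lists a rawPrediction of about 1.0071 for class 0 and a probability of about 0.7325, which agrees with sigmoid(1.0071).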

    2. Code (reference: https://github.com/asker124143222/spark-demo)

    package com.home.spark.ml
    
    import org.apache.spark.SparkConf
    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
    import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.sql.{Dataset, Row, SparkSession}
    
    /**
      * @Description: Logistic regression, binomial classification
      *
      **/
    object Ex_BinomialLogisticRegression {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf(true).setMaster("local[*]").setAppName("spark ml label")
        val spark = SparkSession.builder().config(conf).getOrCreate()
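        // Converting an RDD to a DataFrame or Dataset requires the implicit conversions of a SparkSession instance.
        // Import them here; note that `spark` is the name of the SparkSession object, not a package.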
        import spark.implicits._
    
        val data = spark.sparkContext.textFile("input/iris.data.txt")
          .map(_.split(","))
          .map(a => Iris(
            Vectors.dense(a(0).toDouble, a(1).toDouble, a(2).toDouble, a(3).toDouble),
            a(4))
          ).toDF()
        data.show()
    
        data.createOrReplaceTempView("iris")
    
        val TotalCount = spark.sql("select count(*) from iris")
        println("Record count: " + TotalCount.collect().take(1).mkString)
    
        // Binomial prediction: the sample data has three classes, so exclude Iris-setosa
        val df = spark.sql("select * from iris where label!='Iris-setosa'")
        df.map(r => r(1) + " : " + r(0)).collect().take(10).foreach(println)
        println("Record count after filtering: " + df.count())
    
    
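        /* VectorIndexer
        Helps tree-based ML methods such as decision trees and random forests classify better.
        VectorIndexer indexes the categorical (discrete-valued) features in a dataset's feature vectors.
        It decides automatically which features are categorical and re-indexes them,
        based on maxCategories: a feature with at most maxCategories distinct values is re-indexed to 0..K-1, where K is its number of distinct values (K <= maxCategories).
        A feature with more than maxCategories distinct values is treated as continuous and left unchanged.
        For example, with maxCategories=5, every feature with at most 5 distinct values is re-indexed.
        For index stability, a feature value of 0 is always mapped to index 0, which preserves vector sparsity.
        The default maxCategories is 20.
        */
        // Index the label column and the feature column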
        val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df)
        val featureIndexer = new VectorIndexer()
    //      .setMaxCategories(5) // with 5, every feature column has more than 5 distinct values, so nothing would be re-indexed and the step would be pointless
          .setInputCol("features").setOutputCol("indexedFeatures")
          .fit(df)
    
    
        // Split the dataset into training data (70%) and test data (30%)
        val Array(trainingData, testData): Array[Dataset[Row]] = df.randomSplit(Array(0.7, 0.3))
    
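        /**
          * LR modeling
          * setMaxIter: maximum number of iterations (default 100); training may stop earlier once it converges
          * setTol: convergence tolerance (default 1e-6); each iteration computes an error that shrinks as iterations proceed, and iteration stops once it falls below the tolerance
          * setRegParam: regularization strength (default 0); regularization mainly guards against overfitting, which is likely with a small dataset and many features, so consider increasing it in that case
          * setElasticNetParam: elastic-net mixing ratio between L1 and L2 (default 0); 0 is a pure L2 (ridge) penalty and 1 is a pure L1 (lasso) penalty. L1 promotes sparse coefficients, L2 shrinks them to limit overfitting
          * setLabelCol: label column
          * setFeaturesCol: features column
          * setPredictionCol: prediction column
          * setThreshold: decision threshold for binary classification (default 0.5)
          */
        // Configure logistic regression. setFamily requires an argument: "auto" (the default), "binomial" or "multinomial";
        // "binomial" matches the two classes that remain after filtering.
        val lr = new LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setFamily("binomial")
          .setMaxIter(100).setRegParam(0.3).setElasticNetParam(0.8)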
    
        // Transformer that converts the predicted indices back to string labels
        val labelConverter = new IndexToString()
          .setInputCol("prediction")
          .setOutputCol("predectionLabel")
          .setLabels(labelIndexer.labels)
    
    
        // Build the pipeline
        val lrPipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, lr, labelConverter))
    
        // Fit the model
        val model = lrPipeline.fit(trainingData)
    
        // Predict
        val result = model.transform(testData)
    
        // Print the results
        result.show(200, false)
    
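        // Evaluate the model. MulticlassClassificationEvaluator uses the weighted F1 score by default,
        // so the value printed below is 1 - F1; call .setMetricName("accuracy") to report the accuracy-based error rate instead.
        val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
        val lrF1: Double = evaluator.evaluate(result)
    
        println("Test Error = " + (1.0 - lrF1))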
    
        spark.stop()
      }
    }
    
    
    case class Iris(features: Vector, label: String)
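
    The VectorIndexer comment in the listing above describes a per-feature rule: few distinct values means categorical and re-indexed, many distinct values means continuous and untouched. The following is a simplified sketch of that rule for a single feature column in plain Scala. The object and method names are hypothetical, and the code only mirrors the behavior described in the comment; it is not Spark's implementation.

```scala
object VectorIndexerSketch {
  /**
    * Re-index one feature column as the comment describes:
    * at most maxCategories distinct values => categorical, mapped to 0..K-1;
    * more than maxCategories distinct values => continuous, returned unchanged.
    */
  def indexColumn(values: Seq[Double], maxCategories: Int): Seq[Double] = {
    val distinct = values.distinct.sorted
    if (distinct.size > maxCategories) values
    else {
      // Keep 0.0 at index 0 when present, then the remaining values in ascending order.
      val ordered =
        if (distinct.contains(0.0)) 0.0 +: distinct.filter(_ != 0.0) else distinct
      val index = ordered.zipWithIndex.toMap
      values.map(v => index(v).toDouble)
    }
  }

  def main(args: Array[String]): Unit = {
    val column = Seq(2.4, 2.0, 2.5, 2.0) // illustrative values
    println(indexColumn(column, maxCategories = 5)) // 3 distinct values: re-indexed
    println(indexColumn(column, maxCategories = 2)) // more than 2 distinct values: unchanged
  }
}
```

    This is consistent with the indexedFeatures column in the result table, where the second and fourth features appear as small integer indices while the first and third keep their original values.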

    3. Result

    +-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+
    |features         |label          |indexedLabel|indexedFeatures    |rawPrediction                               |probability                             |prediction|predectionLabel|
    +-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+
    |[4.9,2.4,3.3,1.0]|Iris-versicolor|0.0         |[4.9,3.0,3.3,0.0]  |[1.0071037675553336,-1.0071037675553336]    |[0.7324529695042751,0.2675470304957249] |0.0       |Iris-versicolor|
    |[5.0,2.0,3.5,1.0]|Iris-versicolor|0.0         |[5.0,0.0,3.5,0.0]  |[0.938177922699384,-0.938177922699384]      |[0.7187314594034615,0.2812685405965385] |0.0       |Iris-versicolor|
    |[5.6,2.5,3.9,1.1]|Iris-versicolor|0.0         |[5.6,4.0,3.9,1.0]  |[0.7107814076350716,-0.7107814076350716]    |[0.6705737993354417,0.3294262006645583] |0.0       |Iris-versicolor|
    |[5.6,2.9,3.6,1.3]|Iris-versicolor|0.0         |[5.6,8.0,3.6,3.0]  |[0.6350805242141693,-0.6350805242141693]    |[0.6536405613705153,0.3463594386294846] |0.0       |Iris-versicolor|
    |[5.8,2.7,4.1,1.0]|Iris-versicolor|0.0         |[5.8,6.0,4.1,0.0]  |[0.7314003881315354,-0.7314003881315354]    |[0.6751125028597408,0.32488749714025916]|0.0       |Iris-versicolor|
    |[6.1,2.8,4.7,1.2]|Iris-versicolor|0.0         |[6.1,7.0,4.7,2.0]  |[0.34553320285886,-0.34553320285886]        |[0.5855339747983552,0.41446602520164466]|0.0       |Iris-versicolor|
    |[6.2,2.2,4.5,1.5]|Iris-versicolor|0.0         |[6.2,1.0,4.5,5.0]  |[0.14582457165756946,-0.14582457165756946]  |[0.5363916772629104,0.46360832273708963]|0.0       |Iris-versicolor|
    |[6.4,2.9,4.3,1.3]|Iris-versicolor|0.0         |[6.4,8.0,4.3,3.0]  |[0.39384006721834597,-0.39384006721834597]  |[0.597206774507057,0.40279322549294305] |0.0       |Iris-versicolor|
    |[6.6,3.0,4.4,1.4]|Iris-versicolor|0.0         |[6.6,9.0,4.4,4.0]  |[0.2698323194379575,-0.2698323194379575]    |[0.5670517391689078,0.43294826083109217]|0.0       |Iris-versicolor|
    |[6.7,3.0,5.0,1.7]|Iris-versicolor|0.0         |[6.7,9.0,5.0,7.0]  |[-0.20557969118713126,0.20557969118713126]  |[0.44878532413929256,0.5512146758607075]|1.0       |Iris-virginica |
    |[6.7,3.1,4.4,1.4]|Iris-versicolor|0.0         |[6.7,10.0,4.4,4.0] |[0.2698323194379575,-0.2698323194379575]    |[0.5670517391689078,0.43294826083109217]|0.0       |Iris-versicolor|
    |[7.0,3.2,4.7,1.4]|Iris-versicolor|0.0         |[7.0,11.0,4.7,4.0] |[0.16644355215403328,-0.16644355215403328]  |[0.5415150896404186,0.4584849103595813] |0.0       |Iris-versicolor|
    |[4.9,2.5,4.5,1.7]|Iris-virginica |1.0         |[4.9,4.0,4.5,7.0]  |[-0.033265079047257284,0.033265079047257284]|[0.49168449702809164,0.5083155029719083]|1.0       |Iris-virginica |
    |[5.4,3.0,4.5,1.5]|Iris-versicolor|0.0         |[5.4,9.0,4.5,5.0]  |[0.14582457165756946,-0.14582457165756946]  |[0.5363916772629104,0.46360832273708963]|0.0       |Iris-versicolor|
    |[5.6,2.8,4.9,2.0]|Iris-virginica |1.0         |[5.6,7.0,4.9,10.0] |[-0.43975124481639627,0.43975124481639627]  |[0.39180024423019144,0.6081997557698086]|1.0       |Iris-virginica |
    |[5.6,3.0,4.1,1.3]|Iris-versicolor|0.0         |[5.6,9.0,4.1,3.0]  |[0.4627659120742955,-0.4627659120742955]    |[0.6136701219061476,0.38632987809385244]|0.0       |Iris-versicolor|
    |[5.8,2.7,3.9,1.2]|Iris-versicolor|0.0         |[5.8,6.0,3.9,2.0]  |[0.6212365822826582,-0.6212365822826582]    |[0.6504997376392441,0.34950026236075604]|0.0       |Iris-versicolor|
    |[5.8,2.7,5.1,1.9]|Iris-virginica |1.0         |[5.8,6.0,5.1,9.0]  |[-0.419132264319932,0.419132264319932]      |[0.3967244102962335,0.6032755897037665] |1.0       |Iris-virginica |
    |[5.9,3.0,5.1,1.8]|Iris-virginica |1.0         |[5.9,9.0,5.1,8.0]  |[-0.32958743896751885,0.32958743896751885]  |[0.4183410089972438,0.5816589910027563] |1.0       |Iris-virginica |
    |[6.0,2.9,4.5,1.5]|Iris-versicolor|0.0         |[6.0,8.0,4.5,5.0]  |[0.14582457165756946,-0.14582457165756946]  |[0.5363916772629104,0.46360832273708963]|0.0       |Iris-versicolor|
    |[6.1,3.0,4.6,1.4]|Iris-versicolor|0.0         |[6.1,9.0,4.6,4.0]  |[0.20090647458200817,-0.20090647458200817]  |[0.5500583546439539,0.4499416453560461] |0.0       |Iris-versicolor|
    |[6.2,3.4,5.4,2.3]|Iris-virginica |1.0         |[6.2,13.0,5.4,13.0]|[-0.8807003330135101,0.8807003330135101]    |[0.29303267372325625,0.7069673262767437]|1.0       |Iris-virginica |
    |[6.7,3.1,4.7,1.5]|Iris-versicolor|0.0         |[6.7,10.0,4.7,5.0] |[0.07689872680162013,-0.07689872680162013]  |[0.5192152136737482,0.48078478632625177]|0.0       |Iris-versicolor|
    |[6.7,3.3,5.7,2.5]|Iris-virginica |1.0         |[6.7,12.0,5.7,15.0]|[-1.163178751002261,1.163178751002261]      |[0.23809016943453823,0.7619098305654617]|1.0       |Iris-virginica |
    |[6.8,3.0,5.5,2.1]|Iris-virginica |1.0         |[6.8,9.0,5.5,11.0] |[-0.7360736047366578,0.7360736047366578]    |[0.32386333429517283,0.6761366657048272]|1.0       |Iris-virginica |
    |[6.9,3.1,5.4,2.1]|Iris-virginica |1.0         |[6.9,10.0,5.4,11.0]|[-0.7016106823086834,0.7016106823086834]    |[0.33145521561995817,0.6685447843800418]|1.0       |Iris-virginica |
    |[7.2,3.6,6.1,2.5]|Iris-virginica |1.0         |[7.2,14.0,6.1,15.0]|[-1.3010304407141597,1.3010304407141597]    |[0.21399164655179387,0.7860083534482062]|1.0       |Iris-virginica |
    |[7.7,2.8,6.7,2.0]|Iris-virginica |1.0         |[7.7,7.0,6.7,10.0] |[-1.0600838485199424,1.0600838485199424]    |[0.2572934314622856,0.7427065685377143] |1.0       |Iris-virginica |
    |[7.7,3.0,6.1,2.3]|Iris-virginica |1.0         |[7.7,9.0,6.1,13.0] |[-1.1219407900093334,1.1219407900093334]    |[0.24565146441425778,0.7543485355857422]|1.0       |Iris-virginica |
    |[7.9,3.8,6.4,2.0]|Iris-virginica |1.0         |[7.9,15.0,6.4,10.0]|[-0.9566950812360182,0.9566950812360182]    |[0.2775403823663211,0.7224596176336789] |1.0       |Iris-virginica |
    +-----------------+---------------+------------+-------------------+--------------------------------------------+----------------------------------------+----------+---------------+
    
    Test Error = 0.03314285714285714
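
    The printed value can be cross-checked by hand. MulticlassClassificationEvaluator reports the weighted F1 score unless another metric is selected, and the table shows 30 test rows with a single mistake (one Iris-versicolor predicted as Iris-virginica). The sketch below recomputes 1 - F1 from those counts in plain Scala; the object and method names are hypothetical, and the counts are read off the table above.

```scala
object WeightedF1Sketch {
  /** Per-class F1 from true positives, false positives and false negatives. */
  def f1(tp: Int, fp: Int, fn: Int): Double = {
    val precision = if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
    val recall = if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
    if (precision + recall == 0.0) 0.0 else 2 * precision * recall / (precision + recall)
  }

  /** F1 averaged over classes, weighted by each class's share of the true labels. */
  def weightedF1(pairs: Seq[(Double, Double)]): Double = {
    // Each pair is (true label, predicted label).
    val labels = pairs.map(_._1).distinct
    labels.map { c =>
      val tp = pairs.count { case (l, p) => l == c && p == c }
      val fp = pairs.count { case (l, p) => l != c && p == c }
      val fn = pairs.count { case (l, p) => l == c && p != c }
      val support = pairs.count(_._1 == c)
      f1(tp, fp, fn) * support / pairs.size
    }.sum
  }

  def main(args: Array[String]): Unit = {
    // 17 versicolor (0.0) predicted correctly, 1 versicolor predicted as
    // virginica (1.0), and 12 virginica predicted correctly.
    val pairs = Seq.fill(17)((0.0, 0.0)) ++ Seq((0.0, 1.0)) ++ Seq.fill(12)((1.0, 1.0))
    println("Test Error = " + (1.0 - weightedF1(pairs)))
  }
}
```

    The result is approximately 0.0331, matching the output above. An accuracy-based error rate for the same table would be 1/30, or about 0.0333, so the two metrics are close here but not identical.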
  • Original article: https://www.cnblogs.com/asker009/p/12176982.html