zoukankan      html  css  js  c++  java
  • spark机器学习从0到1逻辑斯蒂回归之(四)

     
    逻辑斯蒂回归

    一、概念

    逻辑斯蒂回归(logistic regression)是统计学习中的经典分类方法,属于对数线性模型。logistic回归的因变量可以是二分类的,也可以是多分类的。logistic回归的因变量可以是二分非线性差分方程类的,也可以是多分类的,但是二分类的更为常用,也更加容易解释。所以实际中最为常用的就是二分类的logistic回归。

    二、logistic分布

    设X是连续随机变量,X服从逻辑斯蒂分布是指X具有下列分布函数和密度函数:

     
    分布函数和密度函数

    式中,μ为位置参数,γ>0为形状参数。

    密度函数是脉冲函数

    分布函数是一条Sigmoid曲线(sigmoid curve)即为阶跃函数

     
    Sigmoid曲线

    三、二项逻辑斯谛回归模型

    二项逻辑斯谛回归模型是如下的条件概率分布

     
    回归模型

    x∊Rn是输入,Y∊{0,1}是输出,w∊Rn和b∊R是参数,

    w称为权值向量,b称为偏置,w·x为w和x的内积。

    可以求得P(Y=1|x)和P(Y=0|x)。

    逻辑斯谛回归比较两个条件概率值的大小,将实例x分到概率值较大的那一类。

    四、LR模型参数估计

    可以应用极大似然估计法估计模型参数

     
    极大似然估计

    对L(w)求极大值,得到w的估计值。

    问题就变成了以对数似然函数为目标函数的最优化问题。

    LR学习中通常采用的方法是梯度下降法及拟牛顿法。

    五、代码实现

    我们以iris数据集(https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data)为例进行分析。iris以鸢尾花的特征作为数据来源,数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。

    import org.apache.spark.SparkConf;
    import org.apache.spark.SparkContext;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.mllib.classification.LogisticRegressionModel;
    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
    import org.apache.spark.mllib.evaluation.MulticlassMetrics;
    import org.apache.spark.mllib.regression.LabeledPoint;
    import org.apache.spark.mllib.util.MLUtils;
    

    5.1、读取数据

    首先,读取文本文件;然后,通过map将每行的数据用“,”隔开,在我们的数据集中,每行被分成了5部分,前4部分是鸢尾花的4个特征,最后一部分是鸢尾花的分类。把这里我们用LabeledPoint来存储标签列和特征列。
    LabeledPoint在监督学习中常用来存储标签和特征,其中要求标签的类型是double,特征的类型是Vector。这里,先把莺尾花的分类进行变换,”Iris-setosa”对应分类0,”Iris-versicolor”对应分类1,其余对应分类2;然后获取莺尾花的4个特征,存储在Vector中。

    SparkConf conf = new  SparkConf().setAppName("LogisticRegression").setMaster("local");
    JavaSparkContext sc = new  JavaSparkContext(conf);
    JavaRDD<String> source =  sc.textFile("data/mllib/iris.data");
    
    JavaRDD<LabeledPoint> data = source.map(line->{
                String[] splits = line.split(",");
                Double label = 0.0;
                if(splits[4].equals("Iris-setosa"))  {
                    label = 0.0;
                }else  if(splits[4].equals("Iris-versicolor")) {
                    label = 1.0;
                }else {
                    label = 2.0;
                }
                return new  LabeledPoint(label,Vectors.dense(Double.parseDouble(splits[0]),
                        Double.parseDouble(splits[1]),
                        Double.parseDouble(splits[2]),
                        Double.parseDouble(splits[3])));
     });
    

    打印数据:

    // 控制台输出结果:
    (0.0,[5.1,3.5,1.4,0.2])
    (0.0,[4.9,3.0,1.4,0.2])
    (0.0,[4.7,3.2,1.3,0.2])
    (0.0,[4.6,3.1,1.5,0.2])
    (0.0,[5.0,3.6,1.4,0.2])
    (0.0,[5.4,3.9,1.7,0.4])
    (0.0,[4.6,3.4,1.4,0.3])
    (0.0,[5.0,3.4,1.5,0.2])
    (0.0,[4.4,2.9,1.4,0.2])
    (0.0,[4.9,3.1,1.5,0.1])
    (0.0,[5.4,3.7,1.5,0.2])
    ... ...
    

    5.2、构建模型:

    // 首先进行数据集的划分,这里划分60%的训练集和40%的测试集:
    JavaRDD<LabeledPoint>[] splits =  data.randomSplit(new double[] {0.6,0.4},11L);
    JavaRDD<LabeledPoint> traning =  splits[0].cache();
    JavaRDD<LabeledPoint> test = splits[1];
    

    构建逻辑斯蒂模型,用set的方法设置参数,比如说分类的数目,这里可以实现多分类逻辑斯蒂模型:

    LogisticRegressionModel model = new LogisticRegressionWithLBFGS().setNumClasses(3).run(traning.rdd());
    

    输出结果:

    org.apache.spark.mllib.classification.LogisticRegressionModel: intercept = 0.0,  numFeatures = 8, numClasses = 3, threshold = 0.5
    

    接下来,调用多分类逻辑斯蒂模型用的predict方法对测试数据进行预测,并把结果保存在MulticlassMetrics中。这里的模型全名为LogisticRegressionWithLBFGS,加上了LBFGS,表示Limited-memory BFGS。其中,BFGS是求解非线性优化问题(L(w)求极大值)的方法,是一种秩-2更新,以其发明者Broyden, Fletcher, Goldfarb和Shanno的姓氏首字母命名。

    JavaPairRDD<Object,Object> predictionAndLables =  test.mapToPair(p->
                new  Tuple2<>(model.predict(p.features()),p.label())
    );
    

    这里,采用了test部分的数据每一行都分为标签label和特征features,然后利用map方法,对每一行的数据进行model.predict(features)操作,获得预测值。并把预测值和真正的标签放到predictionAndLabels中。我们可以打印出具体的结果数据来看一下:

    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (0.0,0.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (2.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (1.0,1.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (1.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    (2.0,2.0)
    

    可以看出,大部分的预测是对的。其中(2.0,1.0),(1.0,2.0)的预测与实际标签不同。

    5.3、模型评估

    模型预测的准确性打印:

    //准确性打印:
    metrics:0.9615384615384616
  • 相关阅读:
    ASP.NET 2.0的页面缓存功能介绍
    我对针对接口编程的浅解
    接口和抽象类的区别
    面向接口编程到底有什么好处
    泛型编程是什么
    方法的重写、重载及隐藏
    基于事件的编程有什么好处
    Socket Remoting WebService对比
    技术讲座:.NET委托、事件及应用兼谈软件项目开发
    ny589 糖果
  • 原文地址:https://www.cnblogs.com/huanghanyu/p/12916832.html
Copyright © 2011-2022 走看看