zoukankan      html  css  js  c++  java
  • LogisticRegression in MLLib (PySpark + numpy+matplotlib可视化)

    参考'LogisticRegression in MLLib' (http://www.cnblogs.com/luweiseu/p/7809521.html)
    通过 PySpark MLlib 训练 logistic 回归模型，再利用 Matplotlib 作图画出分类边界。

    from pyspark.sql import Row
    from pyspark.sql import HiveContext
    import pyspark
    from IPython.display import display
    import matplotlib
    import matplotlib.pyplot as plt
    
    import os
    os.environ['SPARK_HOME'] ="C:\Users\software\spark-2.1.0-bin-hadoop2.7"
    
    %matplotlib inline 
    
    sc = pyspark.SparkContext(master='local').getOrCreate()
    sqlContext = HiveContext(sc)
    
    # get data
    irisData = sc.textFile("iris.txt")
    
    
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.linalg import Vectors
    from pyspark.mllib.classification import LogisticRegressionWithLBFGS
    
    
    def toLabeledPoint(line):
        """Turn one whitespace-delimited line of iris.txt into a LabeledPoint.

        Fields 0 and 1 become a dense two-feature vector (presumably petal
        length and petal width, which is how the plot axes label them);
        field 2 is parsed with ``int`` and used as the class label.
        """
        fields = line.split()
        label = int(fields[2])
        features = Vectors.dense(float(fields[0]), float(fields[1]))
        return LabeledPoint(label, features)
    
    data = irisData.map(toLabeledPoint)

    # Hold out 40% of the records for evaluation; the fixed seed makes the
    # 60/40 split reproducible. Only the training part is cached because
    # the optimiser iterates over it repeatedly.
    splits = data.randomSplit([0.6, 0.4], seed=11)
    training, test = splits[0].cache(), splits[1]

    # 3-class logistic regression with an intercept term, fitted by L-BFGS.
    trainer = LogisticRegressionWithLBFGS()
    model = trainer.train(training, intercept=True, numClasses=3)
    
    # testdata
    def predicTest(lp):
        """Score one held-out LabeledPoint with the trained global ``model``.

        Returns a ``(predicted class, true label)`` pair, the tuple layout
        that ``MulticlassMetrics`` consumes.
        """
        predicted = model.predict(lp.features)
        return float(predicted), lp.label

    predictionAndLabels = test.map(predicTest)


    from pyspark.mllib.evaluation import MulticlassMetrics

    # Accuracy = share of held-out records whose predicted class equals the
    # true label; the trailing bare name echoes the value in the notebook.
    metrics = MulticlassMetrics(predictionAndLabels)
    accuracy = metrics.accuracy
    accuracy
    
    # plot boundary: colour every point of a dense feature grid by the class
    # the trained model assigns to it, then overlay the real samples.
    import numpy as np

    ## meshgrid over the plotted feature range (x: 0..8, y: 0..3.5).
    ## np.meshgrid expects 1-D coordinate arrays; x0 and x1 both have
    ## shape (200, 500).
    x0, x1 = np.meshgrid(
        np.linspace(0, 8, 500),
        np.linspace(0, 3.5, 200),
    )
    X_new = np.c_[x0.ravel(), x1.ravel()]

    ## predict: one driver-side model.predict call per grid point
    ## (200 * 500 = 100,000 calls), so this cell is slow.
    y_predict = [model.predict(Vectors.dense(X_new_i)) for X_new_i in X_new]

    ## Pull the samples to the driver in a single Spark job; `data` is not
    ## cached, so every extra collect() would re-read and re-parse iris.txt.
    points = data.collect()
    y = np.array([p.label for p in points])
    X = np.array([[p.features[0], p.features[1]] for p in points])

    ## draw
    zz = np.array(y_predict).reshape(x0.shape)

    plt.figure(figsize=(10, 4))
    # NOTE(review): assumes iris.txt encodes Setosa/Versicolor/Virginica as
    # 0/1/2 -- confirm against the data file.
    plt.plot(X[y==2, 0], X[y==2, 1], "g^", label="Iris-Virginica")
    plt.plot(X[y==1, 0], X[y==1, 1], "bs", label="Iris-Versicolor")
    plt.plot(X[y==0, 0], X[y==0, 1], "yo", label="Iris-Setosa")

    from matplotlib.colors import ListedColormap
    # Light yellow / blue / green, ordered low -> high to echo the marker
    # colours used for classes 0, 1 and 2 above.
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])

    # No `linewidth` here: contourf draws filled regions only and does not
    # accept that property.
    plt.contourf(x0, x1, zz, cmap=custom_cmap)
    plt.xlabel("Petal length", fontsize=14)
    plt.ylabel("Petal width", fontsize=14)
    plt.legend(loc="center left", fontsize=14)
    plt.axis([0, 7, 0, 3.5])
    plt.show()
    

    最终结果:

  • 相关阅读:
    用 PHP 和 MySQL 保存和输出图片
    windows XP+Apache+PHP5+MySQL的安装与配置方法
    PHP语法学习笔记
    wap开发环境搭建
    三款免费的PHP加速器:APC、eAccelerator、XCache比较
    window下apache与tomcat整合
    用PHP的ob_start();控制您的浏览器cache!
    一些经典常用的正则表达式
    windows下为apache配置多个站点
    windows下IIS与Tomcat共存的问题
  • 原文地址:https://www.cnblogs.com/luweiseu/p/7826679.html
Copyright © 2011-2022 走看看