zoukankan      html  css  js  c++  java
  • python spark kmeans demo

    官方的demo

    from numpy import array
    from math import sqrt
    
    from pyspark import SparkContext
    
    from pyspark.mllib.clustering import KMeans, KMeansModel
    
    # Start the Spark application used by this example.
    sc = SparkContext(appName="clusteringExample")

    def _to_vector(line):
        """Turn one space-separated line of numbers into a numpy array of floats."""
        return array(list(map(float, line.split(' '))))

    # Read the sample file and convert every text line into a feature vector.
    data = sc.textFile("/root/spark-2.1.1-bin-hadoop2.6/data/mllib/kmeans_data.txt")
    parsedData = data.map(_to_vector)

    # Fit a KMeans model with 2 clusters, maxIterations=10 and "random" initialization.
    clusters = KMeans.train(parsedData, 2, maxIterations=10, initializationMode="random")
    
    # Evaluate clustering by computing Within Set Sum of Squared Errors
    def error(point):
        """Return the squared Euclidean distance from `point` to its cluster center.

        The center is the entry of `clusters.centers` selected by
        `clusters.predict(point)`. The distance is deliberately NOT passed
        through sqrt: summing these values over all points gives the
        Within Set Sum of Squared Errors, whereas summing the plain
        distances would not be a sum of *squared* errors.
        """
        center = clusters.centers[clusters.predict(point)]
        diff = point - center
        return sum(x ** 2 for x in diff)
    
    # Add up the per-point error over the whole data set and print the total.
    pointErrors = parsedData.map(error)
    WSSSE = pointErrors.reduce(lambda total, err: total + err)
    print("Within Set Sum of Squared Error = " + str(WSSSE))
    
    # Save and load model
    #clusters.save(sc, "target/org/apache/spark/PythonKMeansExample/KMeansModel")
    #sameModel = KMeansModel.load(sc, "target/org/apache/spark/PythonKMeansExample/KMeansModel")

     带归一化的例子:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.sql.functions.{col, udf}
    
    // One input record: a label plus two numeric features.
    case class DataRow(label: Double, x1: Double, x2: Double)

    // Four sample records, turned into a DataFrame with columns label, x1, x2.
    val sampleRows = Seq(
      DataRow(3, 1, 2),
      DataRow(5, 3, 4),
      DataRow(7, 5, 6),
      DataRow(6, 0, 0)
    )
    val data = sqlContext.createDataFrame(sc.parallelize(sampleRows))

    // Feature vectors built from the columns at positions 1 and 2 (x1, x2), cached.
    val parsedData = data.rdd.map { row =>
      Vectors.dense(row.getDouble(1), row.getDouble(2))
    }.cache()
    // Train KMeans on the feature vectors (k = 3, maxIterations = 20).
    val clusters = KMeans.train(parsedData, 3, 20)

    // UDF returning the cluster index the trained model predicts for an (x1, x2) pair.
    val t = udf((x1: Double, x2: Double) => clusters.predict(Vectors.dense(x1, x2)))
    // Each label next to the predicted cluster of its features.
    val result = data.select(col("label"), t(col("x1"), col("x2")))
    
    The important parts are the last two lines:
    
        Creates a UDF (user-defined function) which can be directly applied to Dataframe columns (in this case, the two columns x1 and x2).
    
        Selects the label column along with the UDF applied to the x1 and x2 columns. Since the UDF predicts the closest cluster, result will be a DataFrame consisting of (label, closestCluster).

    参考:https://stackoverflow.com/questions/31447141/spark-mllib-kmeans-from-dataframe-and-back-again

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.clustering._
    
    // Keep the identifying column (position 0) next to the feature vector so the
    // "id" column of the result really identifies the row. Previously "id" was
    // filled with the first feature value (position 1).
    // NOTE(review): assumes column 0 is a Double, as in DataRow(label, x1, x2) — confirm
    // against the actual schema of `data`.
    val rows = data.rdd.map { r =>
      (r.getDouble(0), Vectors.dense(r.getDouble(1), r.getDouble(2)))
    }.cache()
    // Only the feature vectors are used for training.
    val vectors = rows.map { case (_, features) => features }
    val kMeansModel = KMeans.train(vectors, 3, 20)
    // Pair every row identifier with the cluster predicted for its features.
    val predictions = rows.map { case (id, features) => (id, kMeansModel.predict(features)) }
    val df = predictions.toDF("id", "cluster")
    df.show()

    Create column from RDD

    It's very easy to obtain pairs of ids and clusters in the form of an RDD:

    // (id, feature vector) pairs taken from row positions 0, 1 and 2; cached.
    val idPointRDD = data.rdd.map { row =>
      (row.getInt(0), Vectors.dense(row.getDouble(1), row.getDouble(2)))
    }.cache()
    // Feature vectors only, in the same order as idPointRDD.
    val points = idPointRDD.map { case (_, point) => point }
    val clusters = KMeans.train(points, 3, 20)
    // Predicted cluster index for every feature vector.
    val clustersRDD = clusters.predict(points)
    // Pair each id with the prediction found at the same position.
    val idClusterRDD = idPointRDD.map { case (id, _) => id }.zip(clustersRDD)
    

    Then you create DataFrame from that

    // Turn the (id, cluster) pairs into a DataFrame with columns "id" and "cluster".
    // NOTE(review): toDF on an RDD presumably needs the SQL implicits in scope
    // (e.g. import sqlContext.implicits._) — confirm for the environment in use.
    val idCluster = idClusterRDD.toDF("id", "cluster")
    

    It works because map doesn't change the order of the data in an RDD, which is why you can just zip the ids with the results of the prediction.

    Use UDF (User Defined Function)

    Second method involves using clusters.predict method as UDF:

    // Broadcast the trained model; the function below reads it through the broadcast handle.
    val bcClusters = sc.broadcast(clusters)

    /** Cluster index the broadcast model predicts for the point (x, y). */
    def predict(x: Double, y: Double): Int = {
      val point = Vectors.dense(x, y)
      bcClusters.value.predict(point)
    }

    // Make the function callable from SQL expressions under the name "predict".
    sqlContext.udf.register("predict", predict _)
    

    Now we can use it to add predictions to data:

    // Select the "id" column plus the SQL function "predict" applied to columns x and y,
    // with the function result exposed as column "cluster".
    val idCluster = data.selectExpr("id", "predict(x, y) as cluster")
    

    Keep in mind that the Spark API doesn't allow UDF deregistration. This means that the closure data will be kept in memory.

  • 相关阅读:
    python基础学习8(浅拷贝与深拷贝)
    适配器模式(Adapter)
    NHibernate的调试技巧和Log4Net配置
    查看表字段的相关的系统信息
    Asp.net MVC 3 开发一个简单的企业网站系统
    ie8 自动设置 兼容性 代码
    同时安装vs2010和VS2012后IEnumerable<ModelClientValidationRule>编译错误
    各种合同样本
    使用远程桌面的朋友可能经常会遇到“超出最大允许连接数”的问题,
    弹出窗口全屏显示:window.showModalDialog与window.open全屏显示
  • 原文地址:https://www.cnblogs.com/bonelee/p/7229115.html
Copyright © 2011-2022 走看看