zoukankan      html  css  js  c++  java
  • python spark kmeans demo

    官方的demo

    from numpy import array
    from math import sqrt
    
    from pyspark import SparkContext
    
    from pyspark.mllib.clustering import KMeans, KMeansModel
    
    # Start the Spark application context used by the rest of the script.
    sc = SparkContext(appName="clusteringExample")


    def parse_point(line):
        """Convert one line of space-separated numbers into a numpy float array."""
        return array([float(token) for token in line.split(' ')])


    # Read the sample data set and turn every text line into a feature vector.
    data = sc.textFile("/root/spark-2.1.1-bin-hadoop2.6/data/mllib/kmeans_data.txt")
    parsedData = data.map(parse_point)

    # Fit the k-means model: 2 clusters, up to 10 iterations, "random" initialization mode.
    clusters = KMeans.train(
        parsedData,
        2,
        maxIterations=10,
        initializationMode="random"
    )
    
    # Evaluate clustering by computing Within Set Sum of Squared Errors
    def error(point):
        """Return the squared Euclidean distance from ``point`` to its cluster center.

        The center is the one the trained model ``clusters`` assigns to ``point``.
        ``point`` must support element-wise subtraction with a center (the script
        passes numpy arrays).

        Summing this value over all points gives the Within Set Sum of Squared
        Errors (WSSSE) that the script prints.
        """
        center = clusters.centers[clusters.predict(point)]
        diff = point - center
        # No square root: WSSSE is a sum of *squared* distances. Taking the root
        # would make the total a sum of plain distances, which is not the metric
        # named in the printed message.
        return float(sum([x**2 for x in diff]))
    
    # Compute the per-point error for every vector, then add the errors up.
    pointErrors = parsedData.map(error)
    WSSSE = pointErrors.reduce(lambda total, err: total + err)
    summary = "Within Set Sum of Squared Error = " + str(WSSSE)
    print(summary)
    
    # Save and load model
    #clusters.save(sc, "target/org/apache/spark/PythonKMeansExample/KMeansModel")
    #sameModel = KMeansModel.load(sc, "target/org/apache/spark/PythonKMeansExample/KMeansModel")

     Scala DataFrame 的例子(注意:下面的代码本身并没有做归一化):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.clustering.KMeans
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.sql.functions.{col, udf}
    
    // One sample: a label followed by two numeric features.
    case class DataRow(label: Double, x1: Double, x2: Double)
    // Four hand-written samples, parallelized into an RDD and exposed as a DataFrame.
    val sampleRows = Seq(
        DataRow(label = 3, x1 = 1, x2 = 2),
        DataRow(label = 5, x1 = 3, x2 = 4),
        DataRow(label = 7, x1 = 5, x2 = 6),
        DataRow(label = 6, x1 = 0, x2 = 0)
    )
    val data = sqlContext.createDataFrame(sc.parallelize(sampleRows))
    
    // Build one dense feature vector per row from columns 1 and 2, and cache the RDD.
    val featureVectors = data.rdd.map(row => Vectors.dense(row.getDouble(1), row.getDouble(2)))
    val parsedData = featureVectors.cache()
    // Train k-means on the cached vectors with k = 3 and a maximum of 20 iterations.
    val clusters = KMeans.train(parsedData, 3, 20)
    // Column function: maps an (x1, x2) pair to the cluster index the model predicts for it.
    val t = udf((a: Double, b: Double) => clusters.predict(Vectors.dense(a, b)))
    // Select the label together with the predicted cluster of each row.
    val predictedCluster = t(col("x1"), col("x2"))
    val result = data.select(col("label"), predictedCluster)
    
    The important parts are the last two lines.
    
        Creates a UDF (user-defined function) which can be directly applied to Dataframe columns (in this case, the two columns x1 and x2).
    
        Selects the label column along with the UDF applied to the x1 and x2 columns. Because the UDF returns the closest cluster for each row, result is a DataFrame of (label, closestCluster) pairs.

    参考:https://stackoverflow.com/questions/31447141/spark-mllib-kmeans-from-dataframe-and-back-again

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.clustering._
    
    // Keep each row's identifier (column 0) next to its feature vector, so that the "id"
    // column of the result really holds the identifier. Previously it was filled with the
    // first feature (column 1).
    // NOTE(review): assumes column 0 is a Double, as in DataRow(label, x1, x2); switch to
    // getInt(0) if the id column of `data` is an integer.
    val rows = data.rdd.map(r => (r.getDouble(0), Vectors.dense(r.getDouble(1), r.getDouble(2)))).cache()
    // Feature vectors only: this is what k-means is trained on.
    val vectors = rows.map(_._2)
    // Train k-means with k = 3 and a maximum of 20 iterations.
    val kMeansModel = KMeans.train(vectors, 3, 20)
    // Pair every identifier with the cluster index predicted for its feature vector.
    val predictions = rows.map { case (id, features) => (id, kMeansModel.predict(features)) }
    val df = predictions.toDF("id", "cluster")
    df.show()

    Create column from RDD

    It's very easy to obtain pairs of ids and clusters in form of RDD:

    // Pair the integer id in column 0 with a dense vector built from columns 1 and 2; cache it.
    val idPointRDD = data.rdd.map { row =>
        val id = row.getInt(0)
        val point = Vectors.dense(row.getDouble(1), row.getDouble(2))
        (id, point)
    }.cache()
    // Train on the vectors alone (k = 3, maximum of 20 iterations).
    val clusters = KMeans.train(idPointRDD.map { case (_, point) => point }, 3, 20)
    // Predict a cluster index for every vector, in the same order as the ids.
    val clustersRDD = clusters.predict(idPointRDD.map { case (_, point) => point })
    // Zip the ids with the predictions to obtain (id, cluster) pairs.
    val idsOnly = idPointRDD.map { case (id, _) => id }
    val idClusterRDD = idsOnly.zip(clustersRDD)
    

    Then you create DataFrame from that

    // Name the two tuple fields when converting the (id, cluster) pairs to a DataFrame.
    val columnNames = Seq("id", "cluster")
    val idCluster = idClusterRDD.toDF(columnNames: _*)
    

    This works because map does not change the order of the data in an RDD: both mapped RDDs keep the same partitions and the same element order, which is why you can simply zip the ids with the prediction results.

    Use UDF (User Defined Function)

    The second method uses the clusters.predict method as a UDF:

    // Wrap the trained model in a broadcast variable.
    val bcClusters = sc.broadcast(clusters)
    // Cluster index the broadcast model predicts for the two-dimensional point (x, y).
    def predict(x: Double, y: Double): Int =
        bcClusters.value.predict(Vectors.dense(x, y))
    // Register the function with Spark SQL under the name "predict".
    sqlContext.udf.register("predict", (x: Double, y: Double) => predict(x, y))
    

    Now we can use it to add predictions to data:

    // SQL expression that calls the registered "predict" UDF and names its output "cluster".
    val clusterExpr = "predict(x, y) as cluster"
    val idCluster = data.selectExpr("id", clusterExpr)
    

    Keep in mind that the Spark API does not allow a UDF to be deregistered. This means that the data captured in its closure will be kept in memory.

  • 相关阅读:
    漫谈程序员系列:咦,你也在混日子啊
    JVM加载class文件的原理机制
    maven编译的时候排除junit测试类
    mysql之——存储过程 + 游标 + 事务
    JSP页面之${fn:}内置函数
    生成24位字符串ID__IdGenerator.java
    oracle创建用户,表空间,虚拟路径,导入dbf
    eclipse手动修改默认工作空间
    资源
    http协议发送header+body+json及接收解析
  • 原文地址:https://www.cnblogs.com/bonelee/p/7229115.html
Copyright © 2011-2022 走看看