  • K-means on Spark

    This is the classic k-means example implemented with PySpark on plain RDDs: points are read from a text file, each point is assigned to its nearest center, and the centers are recomputed until they move less than a given threshold.

    from __future__ import print_function
    
    import sys
    
    import numpy as np
    from pyspark.sql import SparkSession
    
    
    def parseVector(line):
        # Parse one line of space-separated floats into a NumPy vector,
        # e.g. "0.0 1.5 2.0" -> array([0.0, 1.5, 2.0])
        return np.array([float(x) for x in line.split(' ')])
    
    
    def closestPoint(p, centers):
        # Return the index of the center nearest to p
        # (by squared Euclidean distance)
        bestIndex = 0
        closest = float("+inf")
        for i in range(len(centers)):
            tempDist = np.sum((p - centers[i]) ** 2)
            if tempDist < closest:
                closest = tempDist
                bestIndex = i
        return bestIndex
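

    # A vectorized equivalent of closestPoint (an illustrative sketch, not part
    # of the original post): compute all K squared distances in one NumPy call.
    def closestPointVectorized(p, centers):
        dists = np.sum((np.asarray(centers) - p) ** 2, axis=1)
        return int(np.argmin(dists))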
    
    
    if __name__ == "__main__":
    
        if len(sys.argv) != 4:
            print("Usage: kmeans <file> <k> <convergeDist>", file=sys.stderr)
            sys.exit(-1)
    
        spark = SparkSession \
            .builder \
            .appName("PythonKMeans") \
            .getOrCreate()
    
        lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
        data = lines.map(parseVector).cache()
        # Hyperparameter K: the number of clusters
        K = int(sys.argv[2])
        # Convergence threshold: iterate until the centers move less than this
        convergeDist = float(sys.argv[3])

        # Initialize the K centers by sampling K points from the data
        kPoints = data.takeSample(False, K, 1)
        tempDist = 1.0

        while tempDist > convergeDist:
            # map -- key: index of the nearest center, value: (point, count 1)
            closest = data.map(
                lambda p: (closestPoint(p, kPoints), (p, 1)))
            # reduce by key: sum the points and counts assigned to each center
            pointStats = closest.reduceByKey(
                lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1]))
            # map: new center = sum of assigned points / their count
            newPoints = pointStats.map(
                lambda st: (st[0], st[1][0] / st[1][1])).collect()

            # Total squared distance the centers moved in this iteration
            tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints)

            for (iK, p) in newPoints:
                kPoints[iK] = p

        print("Final centers: " + str(kPoints))

        spark.stop()
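
    To try the script end to end, you can generate a small input file with NumPy and submit the job with spark-submit. The sketch below is not from the original post: the file names (kmeans.py, points.txt), the two-blob layout, and the arguments K=2 and convergeDist=0.1 are illustrative assumptions; only the input format (one point per line, coordinates separated by single spaces) is fixed by parseVector above.

        # make_points.py -- illustrative helper (not from the original post):
        # writes a toy dataset in the format parseVector() expects.
        import numpy as np

        rng = np.random.RandomState(0)
        # Two well-separated 2-D blobs of 50 points each
        points = np.vstack([rng.randn(50, 2), rng.randn(50, 2) + 10.0])

        with open("points.txt", "w") as f:
            for p in points:
                # One point per line, coordinates separated by single spaces
                f.write(" ".join(str(x) for x in p) + "\n")

        # Then submit the k-means script above (saved as kmeans.py), e.g.:
        #   spark-submit kmeans.py points.txt 2 0.1

    On well-separated blobs like these, the loop should converge in a few iterations, with the two reported centers landing near the blob means.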
  • Original post: https://www.cnblogs.com/energy1010/p/9879043.html