  • Spark-based KMeans algorithm

    from __future__ import print_function
    
    import sys
    
    import numpy as np
    from pyspark.sql import SparkSession
    
    
    def parseVector(line):
        return np.array([float(x) for x in line.split(' ')])
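    # Example (illustrative, not from the original post):
    # parseVector("1.0 2.0 3.0") -> array([1., 2., 3.])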
    
    
    def closestPoint(p, centers):
        bestIndex = 0
        closest = float("+inf")
        for i in range(len(centers)):
            tempDist = np.sum((p - centers[i]) ** 2)
            if tempDist < closest:
                closest = tempDist
                bestIndex = i
        return bestIndex
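    # Example (illustrative, not from the original post): with centers
    # [np.array([0.0, 0.0]), np.array([5.0, 5.0])], the call
    # closestPoint(np.array([1.0, 1.0]), centers) returns 0, since the squared
    # distance to the first center (2.0) is smaller than to the second (32.0).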
    
    
    if __name__ == "__main__":
    
        if len(sys.argv) != 4:
            print("Usage: kmeans <file> <k> <convergeDist>", file=sys.stderr)
            sys.exit(-1)
    
        # a multi-line builder chain needs explicit line continuations
        spark = SparkSession \
            .builder \
            .appName("PythonKMeans") \
            .getOrCreate()
    
        lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0])
        data = lines.map(parseVector).cache()
        # Clustering hyperparameter: number of clusters K
        K = int(sys.argv[2])
        # Convergence threshold
        convergeDist = float(sys.argv[3])
        # Initialize the K cluster centers by sampling from the data
        kPoints = data.takeSample(False, K, 1)
        tempDist = 1.0

        while tempDist > convergeDist:
            # map -- key: index of the closest center, value: (point, count 1)
            closest = data.map(
                lambda p: (closestPoint(p, kPoints), (p, 1)))
            # reduceByKey -- per center, sum the assigned points and their counts
            pointStats = closest.reduceByKey(
                lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1]))
            # map -- the new center is the mean of the points assigned to it
            newPoints = pointStats.map(
                lambda st: (st[0], st[1][0] / st[1][1])).collect()

            # total squared movement of the centers in this iteration
            tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints)

            for (iK, p) in newPoints:
                kPoints[iK] = p

        print("Final centers: " + str(kPoints))

        spark.stop()
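
To try the example end to end, the input file just needs one space-separated vector per line, matching parseVector above. The following sketch (not from the original post; the file name kmeans_data.txt and the blob parameters are illustrative) writes a small three-cluster 2-D dataset:

    import numpy as np

    rng = np.random.default_rng(0)
    # three well-separated 2-D Gaussian blobs, 50 points each
    blob_centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
    points = np.vstack(
        [c + rng.normal(scale=0.5, size=(50, 2)) for c in blob_centers])

    with open("kmeans_data.txt", "w") as f:
        for p in points:
            f.write(" ".join(str(x) for x in p) + "\n")

The script can then be submitted as, for example, spark-submit kmeans.py kmeans_data.txt 3 0.01 (matching the Usage string above); with K=3 the reported final centers should land near the three blob centers.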
  • Original article: https://www.cnblogs.com/energy1010/p/9879043.html