zoukankan      html  css  js  c++  java
  • k_means算法+python实现




    一、原理

    K均值算法使用的聚类准则函数是误差平方和准则,通过反复迭代优化聚类结果,使所有样本到各自所属类别的中心的距离平方和达到最小。



    二、算法步骤

    设迭代次数 r = 0

    1. 如果把数据分成k个类,则第一步选前k个点作为第一批聚类中心:Z1(r ),Z2(r )…Zk(r )
    2. 将所有的数据与各个聚类中心求距离(根据实际情况选择欧式、马氏等距离),然后将各数据点分配到离自己最近的聚类中心(相当于分类)。
    3. 对于分好的类,求每个类的重心,作为新的聚类中心。获得新一批的聚类中心Z1(r+1)、Z2(r+1)…Zk(r+1)
    4. 如果新一批的聚类中心与上一批的聚类中心完全相等,则停止迭代,否则重复步骤2~4



    三、实例如下:

    根据调查得到某地10所学校的数据(见下表),试采用k_means算法编写程序,将这些学校按三种类别聚类。
    在这里插入图片描述



    四、python代码实现:

    import numpy as np
    
    '''
        k-means算法
    '''
    
    #标签
    label_set = [
        '学校1','学校2','学校3','学校4','学校5',
        '学校6','学校7','学校8','学校9','学校10'
    ]
    #数据
    data_set = np.array([
        [2088,562.05,42,434],
        [10344.8,4755,76,1279],
        [2700,4100,56,820],
        [3967,3751,67,990],
        [5850.24,6173.25,78,1240],
        [1803.26,5224.99,72,1180],
        [2268,8011,56,800],
        [32000,18000,200,2000],
        [100000,30000,200,1100],
        [173333,60000,420,2552]
    ])
    
    #标准化
    def normal_dataSet(data_set):
        mean = np.mean(data_set,axis=0)
        std = np.std(data_set,axis=0)
        dataSet =  (data_set-mean)/std
        return dataSet
    
    
    
    #计算欧氏距离
    def O_distance(x, y):
        dis = np.sqrt(np.sum(np.square(x-y)))
        return dis
    
    #第一步获取聚类中心(直接获取前k个作为中心)
    def get_cluster_center(dataSet, k):
        Z = []
        for i in range(k):
            Z.append(dataSet[i])
        return np.array(Z)
    
    #根据离聚类中心Z的距离分类
    def classify(dataSet, Z):
        result = {}
        for i in range(len(Z)):
            result['第'+str(i+1)+'类'] = []
        for j in range(len(dataSet)):
            min_class = 0 #初始类
            min_dis = O_distance(dataSet[j],Z[0]) #初始最小的距离
            for i in range(len(Z)):
                dis = O_distance(dataSet[j],Z[i])
                min_dis = dis if dis < min_dis else min_dis
                if(min_dis == dis):
                    min_class = i
            result['第'+str(min_class+1)+'类'].append(j)
        return result
    
    #获取新的聚类中心
    def get_new_cluster_center(result,dataSet):
        Z=[]
        new_result = {}
        #因为result保存的是各类别对应的各点在dataSet的下标
        #需要将下标转化为dataSet中实际值
        for key in result.keys():
            new_result[key] = []
            for index in result[key]:
                new_result[key].append(dataSet[index])
            avg = np.mean(np.array(new_result[key]),axis=0)
            Z.append(avg)
        return np.array(Z)
    
    
    
    #k_means算法,将数据集分成k份
    def k_means(dataSet, k):
        result = {} #分类结果
        Z = get_cluster_center(dataSet, k) #初始的聚类中心
        result = classify(dataSet, Z) #第一次分类
    
        old_Z = Z
        new_Z = get_new_cluster_center(result,dataSet) #获取新的聚类中心
        #迭代
        while ((old_Z!=new_Z).any()):
            result = classify(dataSet, new_Z)
            old_Z = new_Z.copy()
            new_Z = get_new_cluster_center(result,dataSet)
        return result
    
    
        
    
    
    
    # k_means(data_set_1,None,2)
    
    dataSet = normal_dataSet(data_set)#标准化处理
    result = k_means(dataSet ,3)#分步聚类
    #打印分类结果
    for key in result.keys():
        print(key,end=': ')
        for index in result[key]:
            print(label_set[index],end=' ')
        print()
    

    运行结果如下:

    第1类: 学校1
    第2类: 学校8 学校9 学校10 
    第3类: 学校2 学校3 学校4 学校5 学校6 学校7 
    
  • 相关阅读:
    一致性哈希算法
    Discourse 的标签(Tag)只能是小写的原因
    JIRA 链接 bitbucket 提示错误 Invalid OAuth credentials
    JIRA 如何连接到云平台的 bitbucket
    Apache Druid 能够支持即席查询
    如何在 Discourse 中配置使用 GitHub 登录和创建用户
    Apache Druid 是什么
    Xshell 如何导入 PuTTYgen 生成的 key
    windows下配置Nginx支持php
    laravel连接数据库提示mysql_connect() :Connection refused...
  • 原文地址:https://www.cnblogs.com/theory/p/11884309.html
Copyright © 2011-2022 走看看