zoukankan      html  css  js  c++  java
  • 机器学习---kmeans聚类的python实现

    """
    Name: study_kmeans.py
    Author: KX-Lau
    Time: 2020/11/6 16:59
    Desc: 实现kmeans聚类
    """
    
    import math
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn import datasets
    from sklearn.cluster import KMeans
    
    
    # -----------不使用sklearn实现kmeans聚类 -------------
    class MyKmeans:
        """Minimal k-means clustering (Lloyd's algorithm) implemented with NumPy.

        Args:
            k: Number of cluster centers.
            n: Maximum number of assign/update iterations.
        """

        def __init__(self, k, n=50):
            self.k = k  # number of cluster centers
            self.n = n  # iteration cap

        def fit(self, x, centers=None):
            """Cluster the rows of ``x`` into ``self.k`` groups.

            Args:
                x: Array-like of shape (n_samples, n_features).
                centers: Optional initial centers, array-like of shape
                    (k, n_features). When omitted, ``self.k`` distinct rows of
                    ``x`` are drawn at random. The argument is never modified.

            Returns:
                A tuple ``(points_set, centers)``. ``points_set`` maps each
                cluster index (0..k-1) to the list of points placed in it by
                the last assignment pass. ``centers`` is a float array holding
                the final centers. If the iteration cap is reached before the
                centers stop moving, ``points_set`` reflects the assignment
                made just before the final center update.

            Raises:
                ValueError: If no initial centers are given and ``self.k``
                    exceeds the number of rows in ``x``.
            """
            x = np.asarray(x)

            # 1. Pick k initial centers. Sampling WITHOUT replacement guarantees
            #    distinct rows; duplicated centers would leave a cluster empty.
            if centers is None:
                if self.k > len(x):
                    raise ValueError(
                        f"cannot pick {self.k} distinct centers from {len(x)} points"
                    )
                index = np.random.choice(len(x), size=self.k, replace=False)
                centers = x[index]

            # Work on a private float copy: the caller's array stays untouched and
            # integer inputs do not truncate the computed means.
            centers = np.array(centers, dtype=float)

            # Defined up front so that n == 0 still returns a well-formed result.
            points_set = {key: [] for key in range(self.k)}

            for _ in range(self.n):
                points_set = {key: [] for key in range(self.k)}

                # 2. Put every point into the bucket of its nearest center.
                #    Squared distance is enough for argmin (sqrt is monotonic).
                for point in x:
                    nearest_index = int(np.argmin(np.sum((centers - point) ** 2, axis=1)))
                    points_set[nearest_index].append(point)

                # 3. Move each center to the mean of its bucket. An empty bucket
                #    keeps its previous center instead of dividing by zero.
                new_centers = centers.copy()
                for i_k, members in points_set.items():
                    if members:
                        new_centers[i_k] = np.mean(members, axis=0)

                # Unchanged centers mean every further pass would repeat this
                # one exactly, so stop early.
                if np.array_equal(new_centers, centers):
                    break
                centers = new_centers

            return points_set, centers
    
    
    """
    iris中文名是鸢尾花卉数据集, 是一类多重变量分析的数据集.
    包含150个样本, 分为3类(山鸢尾Setosa, 变色鸢尾Versicolor, 维吉尼亚鸢尾Virginica), 
    每个类别50个数据, 每个数据包含4个属性(花萼长度, 花萼宽度, 花瓣长度, 花瓣宽度).
    """
    
    # Load the iris data and keep only its first two columns so the
    # clustering result can be drawn on a 2-D scatter plot.
    iris = datasets.load_iris()
    data = iris['data'][:, :2]
    print(type(data))

    # Cluster with the hand-written implementation (random initial centers).
    mk = MyKmeans(3)
    point_sets, centers = mk.fit(data)

    # One array per cluster, in cluster-index order.
    category1, category2, category3 = (np.asarray(point_sets[label]) for label in range(3))

    # Cluster centers: large yellow triangles with a black outline.
    center_style = dict(s=200, marker='^', color='yellow', edgecolors='black')
    for p in centers:
        plt.scatter(p[0], p[1], **center_style)

    # Cluster members: one colour per cluster.
    for members, colour in zip((category1, category2, category3), ('g', 'r', 'b')):
        plt.scatter(members[:, 0], members[:, 1], color=colour)

    plt.xlim(4, 8)
    plt.ylim(1, 5)
    plt.title('kmeans with k=3')
    plt.show()
    
    # ----------- k-means clustering with scikit-learn -------------
    # Use three fixed rows of the data as the initial centroids.
    init = data[[5, 109, 121]]
    kmeans = KMeans(n_clusters=3, init=init, max_iter=100)
    kmeans.fit(data)
    labels = kmeans.labels_
    cluster_centers = kmeans.cluster_centers_

    # Split the samples by the label scikit-learn assigned to them.
    c1, c2, c3 = (data[labels == label] for label in range(3))

    print('cluster_centers', cluster_centers)
    print('init', init)

    plt.figure()

    # Cluster centers: large yellow triangles with a black outline.
    center_style = dict(color='yellow', edgecolors='black', s=200, marker='^')
    for p in cluster_centers:
        plt.scatter(p[0], p[1], **center_style)

    # Cluster members: one colour per label.
    for members, colour in zip((c1, c2, c3), ('g', 'r', 'b')):
        plt.scatter(members[:, 0], members[:, 1], color=colour)

    plt.xlim(4, 8)
    plt.ylim(1, 5)
    plt.title('kmeans using sklearn with k=3')
    plt.show()
    
  • 相关阅读:
    《学习之道》第二章学习方法7看视频
    《学习之道》第二章学习6阅读书籍
    反射详解一
    spring 初始化和销毁的三种方法
    文件读取
    JdbcTemplate 详解二
    JdbcTemplate 详解一
    JdbcTemplate 详解三
    常用commons 工具类依赖配置
    java 8 stream
  • 原文地址:https://www.cnblogs.com/KX-Lau/p/13955098.html
Copyright © 2011-2022 走看看