zoukankan      html  css  js  c++  java
  • kmeans聚类的实现

    Kmeans算法流程

    从数据中随机抽取k个点作为初始聚类的中心,由这个中心代表各个聚类
    计算数据中所有的点到这k个点的距离,将点归到离其最近的聚类里
    调整聚类中心,即将聚类的中心移动到聚类的几何中心(即平均值)处,也就是k-means中的mean的含义
    重复第2步直到聚类的中心不再移动,此时算法收敛
    最后kmeans算法时间、空间复杂度是:
    时间复杂度:上限为O(tKmn),下限为Ω(Kmn)其中,t为迭代次数,K为簇的数目,m为记录数,n为维数 
    空间复杂度:O((m+K)n),其中,K为簇的数目,m为记录数,n为维数


    影响算法准确度的因素

    数据的采集和抽象

    初始的中心选择

    k值的选定

    最大迭代次数

    收敛值

    度量距离的手段


    在进一步阐述初始中心点选择之前,我们应该先确定度量kmeans的算法精确度的方法。一种度量聚类效果的标准是:SSE(Sum of Square Error,误差平方和)
    SSE越小表示数据点越接近于它们的质心,聚类效果也就越好。因为对误差取了平方所以更重视那些远离中心的点。
    一种可以肯定降低SSE的方法是增加簇的个数。但这违背了聚类的目标。因为聚类是在保持目标簇不变的情况下提高聚类的质量。
    现在思路明了了我们首先以缩小SSE为目标改进算法。


    二分K均值

    为了克服k均值算法收敛于局部的问题,提出了二分k均值算法。该算法首先将所有的点作为一个簇,然后将该簇一分为二。之后选择其中一个簇继续划分,选择哪个簇进行划分取决于对其划分是否可以最大程度降低SSE值。
    伪代码如下:
    将所有的点看成一个簇
    当簇数目小于k时
    对于每一个簇
    计算总误差
    在给定的簇上面进行K均值聚类(K=2)
    计算将该簇一分为二后的总误差
    选择使得误差最小的那个簇进行划分操作



    // spectral-cluster.cpp : 定义控制台应用程序的入口点。
    //
    
    #include "stdafx.h"
    #include <iostream> 
    #include<vector>
    #include<time.h>
    #include<cstdlib>
    #include<set>
    
    using namespace std;
    
    //template <typename T>
    class kmeans
    {
    private:
    	vector<vector<int>>dataset;
    	unsigned int k;
    	unsigned int dim;
    	typedef vector<double> Centroid;
    	vector<Centroid> center;
    	vector<set<int>>cluster_ID;
    	vector<Centroid>new_center;
    	vector<set<int>>new_cluster_ID;
    	double threshold;
    	int iter;
    private:
    	void init();
    	void assign();
    	double distance(Centroid cen, int k2);
    	void split(vector<set<int>>&clusters, int kk);
    	void update_centers();
    	bool isfinish();
    	void show_result();
    	void generate_data()
    	{
    		dim = 2;
    		threshold = 0.001;
    		k = 4;
    		for (int i = 0; i < 300; i++)
    		{
    			vector<int>data;
    			data.resize(dim);
    			for (int j = 0; j < dim; j++)
    				data[j] = double(rand()) / double(RAND_MAX + 1.0) * 500;
    			dataset.push_back(data);
    		}
    	}
    public:
    	kmeans()
    	{
    		time_t t;
    		srand(time(&t));
    	}
    	void apply();
    };
    
    //template <typename T>
    void kmeans::init()
    {
    	center.resize(k);
    
    	set<int>bb;
    	for (int i = 0; i < k; i++)
    	{
    
    		int id = double(rand()) / double(RAND_MAX + 1.0)*dataset.size();
    		while (bb.find(id) != bb.end())
    		{
    			id = double(rand()) / double(RAND_MAX + 1.0)*dataset.size();
    		}
    		bb.insert(id);
    		center[i].resize(dim);
    		for (int j = 0; j < dim; j++)
    			center[i][j] = dataset[id][j];
    
    	}
    }
    bool kmeans::isfinish()
    {
    	double error = 0;
    	for (int i = 0; i < k; i++)
    	{
    		for (int j = 0; j < dim; j++)
    			error += pow(center[i][j] - new_center[i][j], 2);
    	}
    	return error < threshold ? true : false;
    }
    void kmeans::assign()
    {
    
    	for (int j = 0; j < dataset.size(); j++)
    	{
    		double mindis = 10000000;
    		int belongto = -1;
    		for (int i = 0; i < k; i++)
    		{
    			double dis = distance(center[i], j);
    			if (dis < mindis)
    			{
    				mindis = dis;
    				belongto = i;
    			}
    		}
    		new_cluster_ID[belongto].insert(j);
    	}
    	for (int i = 0; i < k; i++)
    	{
    		if (new_cluster_ID[i].empty())
    		{
    			split(new_cluster_ID, i);
    		}
    	}
    }
    
    double kmeans::distance(Centroid cen, int k2)
    {
    	double dis = 0;
    	for (int i = 0; i < dim; i++)
    		dis += pow(cen[i] - dataset[k2][i], 2);
    	return sqrt(dis);
    }
    
    void kmeans::split(vector<set<int>>&clusters, int kk)
    {
    	int maxsize = 0;
    	int th = -1;
    	for (int i = 0; i < k; i++)
    	{
    		if (clusters[i].size() > maxsize)
    		{
    			maxsize = clusters[i].size();
    			th = i;
    		}
    	}
    #define DELTA 1
    	vector<double>tpc1, tpc2;
    	tpc1.resize(dim);
    	tpc2.resize(dim);
    	for (int i = 0; i < dim; i++)
    	{
    		tpc2[i] = center[th][i] - DELTA;
    		tpc1[i] = center[th][i] + DELTA;
    	}
    	for (set<int>::iterator it = clusters[th].begin(); it != clusters[th].end(); it++)
    	{
    		double d1 = distance(tpc1, *it);
    		double d2 = distance(tpc2, *it);
    		if (d2 < d1)
    		{
    			clusters[kk].insert(*it);
    		}
    	}
    	_ASSERTE(!clusters[kk].empty());
    	for (set<int>::iterator it = clusters[kk].begin(); it != clusters[kk].end(); it++)
    		clusters[th].erase(*it);
    
    }
    
    void kmeans::update_centers()
    {
    	for (int i = 0; i < k; i++)
    	{
    		Centroid temp;
    		temp.resize(dim);
    		for (set<int>::iterator j = new_cluster_ID[i].begin(); j != new_cluster_ID[i].end(); j++)
    		{
    			for (int m = 0; m < dim; m++)
    				temp[m] += dataset[*j][m];
    		}
    		for (int m = 0; m < dim; m++)
    			temp[m] /= new_cluster_ID[i].size();
    		new_center[i] = temp;
    	}
    }
    
    void kmeans::apply()
    {
    	generate_data();
    	init();
    	new_center.resize(k);
    	new_cluster_ID.resize(k);
    	assign();
    	update_centers();
    	iter = 0;
    	while (!isfinish())
    	{
    		center = new_center;
    		cluster_ID = new_cluster_ID;
    		new_center.clear();
    		new_center.resize(k);
    		new_cluster_ID.clear();
    		new_cluster_ID.resize(k);
    		assign();
    		update_centers();
    		iter++;
    	}
    	show_result();
    }
    
    void kmeans::show_result()
    {
    	int num = 0;
    	for (int i = 0; i < k; i++)
    	{
    		char string[100];
    		sprintf(string, "第个%d簇:", i);
    		cout << string << endl;
    		cout << "中心为 (" << center[i][0] << "," << center[i][1] << ")" << endl;
    		for (set<int>::iterator it = cluster_ID[i].begin(); it != cluster_ID[i].end(); it++)
    		{
    			sprintf(string, "编号%d   ", *it);
    			cout << string << "(" << dataset[*(it)][0] << "," << dataset[*(it)][1] << ")" << endl;
    			num++;
    		}
    
    		cout << endl << endl;
    	}
    
    	_ASSERTE(num == dataset.size());
    }
    
    int _tmain(int argc, _TCHAR* argv[])
    {
    	
    	kmeans km;
    	km.apply();
    
    	system("pause");
    
    
    	return 0;
    }
    


    版权声明:

  • 相关阅读:
    MySQL运维案例分析:Binlog中的时间戳
    身边有位“别人家的程序员”是什么样的体验?
    苹果收取30%过路费_你是顶是踩?
    1019 数字黑洞 (20 分)C语言
    1015 德才论 (25 分)C语言
    1017 A除以B (20 分)C语言
    1014 福尔摩斯的约会 (20 分)
    求n以内最大的k个素数以及它们的和、数组元素循环右移问题、求最大值及其下标、将数组中的数逆序存放、矩阵运算
    1005 继续(3n+1)猜想 (25 分)
    爬动的蠕虫、二进制的前导的零、求组合数、Have Fun with Numbers、近似求PI
  • 原文地址:https://www.cnblogs.com/walccott/p/4956866.html
Copyright © 2011-2022 走看看