zoukankan      html  css  js  c++  java
  • 二分Kmeans的java实现

    刚刚研究了Kmeans。Kmeans是一种十分简单的聚类算法。可是他十分依赖于用户最初给定的k值。它无法发现随意形状和大小的簇。最适合于发现球状簇。他的时间复杂度为O(tkn)。kmeans算法有两个核心点:计算距离的公式&推断迭代停止的条件。一般距採用欧式距离等能够随意。推断迭代停止的条件能够有:

    1) 每一个簇的中心点不再变化则停止迭代

    2)全部簇的点与这个簇的中心点的误差平方和(SSE)的全部簇的总和不再变化

    3)设定人为的迭代次数。观察实验效果。


    当初始簇心选择不好的时候聚类的效果会非常差。

    所以后来又有一个人提出了二分k均值(bisectingkmeans),其核心思路是:将初始的一个簇一分为二计算出误差平方和最大的那个簇,对他进行再一次的二分。直至切分的簇的个数为k个停止。

    事实上质就是不断的对选中的簇做k=2的kmeans切分。

    由于聚类的误差平方和可以衡量聚类性能,该值越小表示数据点月接近于它们的质心。聚类效果就越好。所以我们就须要对误差平方和最大的簇进行再一次的划分。由于误差平方和越大,表示该簇聚类越不好,越有可能是多个簇被当成一个簇了。所以我们首先须要对这个簇进行划分。


    以下是代码,kmeans的原始代码来源于http://blog.csdn.net/cyxlzzs/article/details/7416491,我稍作了一些改动。


    package org.algorithm;
    
    import java.util.ArrayList;
    import java.util.List;
    
    /**
     * 二分k均值。实际上是对一个集合做多次的k=2的kmeans划分。 每次划分后会对sse值较大的簇再进行二分。 终于使得或分出来的簇的个数为k个则停止
     * 
     * 这里利用之前别人写好的一个kmeans的java实现作为基础类。

    * * @author l0979365428 * */ public class BisectingKmeans { private int k;// 分成多少簇 private List<float[]> dataSet;// 当前要被二分的簇 private List<ClusterSet> cluster; // 簇 /** * @param args */ public static void main(String[] args) { // 初始化一个Kmean对象,将k置为10 BisectingKmeans bkm = new BisectingKmeans(5); // 初始化试验集 ArrayList<float[]> dataSet = new ArrayList<float[]>(); dataSet.add(new float[] { 1, 2 }); dataSet.add(new float[] { 3, 3 }); dataSet.add(new float[] { 3, 4 }); dataSet.add(new float[] { 5, 6 }); dataSet.add(new float[] { 8, 9 }); dataSet.add(new float[] { 4, 5 }); dataSet.add(new float[] { 6, 4 }); dataSet.add(new float[] { 3, 9 }); dataSet.add(new float[] { 5, 9 }); dataSet.add(new float[] { 4, 2 }); dataSet.add(new float[] { 1, 9 }); dataSet.add(new float[] { 7, 8 }); // 设置原始数据集 bkm.setDataSet(dataSet); // 运行算法 bkm.execute(); // 得到聚类结果 // ArrayList<ArrayList<float[]>> cluster = bkm.getCluster(); // 查看结果 // for (int i = 0; i < cluster.size(); i++) { // bkm.printDataArray(cluster.get(i), "cluster[" + i + "]"); // } } public BisectingKmeans(int k) { // 比2还小有啥要划分的意义么 if (k < 2) { k = 2; } this.k = k; } /** * 设置需分组的原始数据集 * * @param dataSet */ public void setDataSet(ArrayList<float[]> dataSet) { this.dataSet = dataSet; } /** * 运行算法 */ public void execute() { long startTime = System.currentTimeMillis(); System.out.println("BisectingKmeans begins"); BisectingKmeans(); long endTime = System.currentTimeMillis(); System.out.println("BisectingKmeans running time=" + (endTime - startTime) + "ms"); System.out.println("BisectingKmeans ends"); System.out.println(); } /** * 初始化 */ private void init() { int dataSetLength = dataSet.size(); if (k > dataSetLength) { k = dataSetLength; } } /** * 初始化簇集合 * * @return 一个分为k簇的空数据的簇集合 */ private ArrayList<ArrayList<float[]>> initCluster() { ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>(); for (int i = 0; i < k; i++) { cluster.add(new ArrayList<float[]>()); } return cluster; } /** * Kmeans算法核心过程方法 */ private void BisectingKmeans() { init(); if (k < 2) { // 小于2 则原样输出数据集被觉得是仅仅分了一个簇 ClusterSet cs = new ClusterSet(); cs.setClu(dataSet); cluster.add(cs); } // 调用kmeans进行二分 cluster = new ArrayList(); while (cluster.size() < k) { List<ClusterSet> clu = kmeans(dataSet); for (ClusterSet cl : clu) { cluster.add(cl); } if (cluster.size() == k) break; else// 顺序计算他们的误差平方和 { float maxerro=0f; int maxclustersetindex=0; int i=0; for (ClusterSet tt : cluster) { //计算误差平方和并得出误差平方和最大的簇 float erroe = CommonUtil.countRule(tt.getClu(), tt .getCenter()); tt.setErro(erroe); if(maxerro<erroe) { maxerro=erroe; maxclustersetindex=i; } i++; } dataSet=cluster.get(maxclustersetindex).getClu(); cluster.remove(maxclustersetindex); } } int i=0; for(ClusterSet sc:cluster) { CommonUtil.printDataArray(sc.getClu(),"cluster"+i); i++; } } /** * 调用kmeans得到两个簇。 * * @param dataSet * @return */ private List<ClusterSet> kmeans(List<float[]> dataSet) { Kmeans k = new Kmeans(2); // 设置原始数据集 k.setDataSet(dataSet); // 运行算法 k.execute(); // 得到聚类结果 List<List<float[]>> clus = k.getCluster(); List<ClusterSet> clusterset = new ArrayList<ClusterSet>(); int i = 0; for (List<float[]> cl : clus) { ClusterSet cs = new ClusterSet(); cs.setClu(cl); cs.setCenter(k.getCenter().get(i)); clusterset.add(cs); i++; } return clusterset; } class ClusterSet { private float erro; private List<float[]> clu; private float[] center; public float getErro() { return erro; } public void setErro(float erro) { this.erro = erro; } public List<float[]> getClu() { return clu; } public void setClu(List<float[]> clu) { this.clu = clu; } public float[] getCenter() { return center; } public void setCenter(float[] center) { this.center = center; } } }


    package org.algorithm;
    
    import java.util.List;
    
    /**
     * 把计算距离和误差的公式抽离出来
     * @author l0979365428
     *
     */
    public class CommonUtil {
    
    	/**
    	 * 计算两个点之间的距离
    	 * 
    	 * @param element
    	 *            点1
    	 * @param center
    	 *            点2
    	 * @return 距离
    	 */
    	public static  float distance(float[] element, float[] center) {
    		float distance = 0.0f;
    		float x = element[0] - center[0];
    		float y = element[1] - center[1];
    		float z = x * x + y * y;
    		distance = (float) Math.sqrt(z);
    
    		return distance;
    	}
    	/**
    	 * 求两点误差平方的方法
    	 * 
    	 * @param element
    	 *            点1
    	 * @param center
    	 *            点2
    	 * @return 误差平方
    	 */
    	public static  float errorSquare(float[] element, float[] center) {
    		float x = element[0] - center[0];
    		float y = element[1] - center[1];
    
    		float errSquare = x * x + y * y;
    
    		return errSquare;
    	}
    	/**
    	 * 计算误差平方和准则函数方法
    	 */
    	public static  float countRule( List<float[]> cluster,float[] center) {
    		float jcF = 0;
    	
    			for (int j = 0; j < cluster.size(); j++) {
    				jcF += CommonUtil.errorSquare(cluster.get(j), center);
    
    			}
    		
    	return  jcF;
    	}
    	/**
    	 * 打印数据。測试用
    	 * 
    	 * @param dataArray
    	 *            数据集
    	 * @param dataArrayName
    	 *            数据集名称
    	 */
    	public static  void printDataArray(List<float[]> dataArray, String dataArrayName) {
    		for (int i = 0; i < dataArray.size(); i++) {
    			System.out.println("print:" + dataArrayName + "[" + i + "]={"
    					+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
    		}
    		System.out.println("===================================");
    	}
    }
    

    package org.algorithm;
    
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Random;
    
    /**
     * K均值聚类算法
     */
    public class Kmeans {
    	private int k;// 分成多少簇
    	private int m;// 迭代次数
    	private int dataSetLength;// 数据集元素个数,即数据集的长度
    	private List<float[]> dataSet;// 数据集链表
    	private List<float[]> center;// 中心链表
    	private List<List<float[]>> cluster; // 簇
    	private List<Float> jc;// 误差平方和,k越接近dataSetLength,误差越小
    	private Random random;
    
    	public static void main(String[] args) {
    		// 初始化一个Kmean对象,将k置为10
    		Kmeans k = new Kmeans(5);
    		// 初始化试验集
    		ArrayList<float[]> dataSet = new ArrayList<float[]>();
    
    		dataSet.add(new float[] { 1, 2 });
    		dataSet.add(new float[] { 3, 3 });
    		dataSet.add(new float[] { 3, 4 });
    		dataSet.add(new float[] { 5, 6 });
    		dataSet.add(new float[] { 8, 9 });
    		dataSet.add(new float[] { 4, 5 });
    		dataSet.add(new float[] { 6, 4 });
    		dataSet.add(new float[] { 3, 9 });
    		dataSet.add(new float[] { 5, 9 });
    		dataSet.add(new float[] { 4, 2 });
    		dataSet.add(new float[] { 1, 9 });
    		dataSet.add(new float[] { 7, 8 });
    		// 设置原始数据集
    		k.setDataSet(dataSet);
    		// 运行算法
    		k.execute();
    		// 得到聚类结果
    		List<List<float[]>> cluster = k.getCluster();
    		// 查看结果
    		for (int i = 0; i < cluster.size(); i++) {
    			CommonUtil.printDataArray(cluster.get(i), "cluster[" + i + "]");
    		}
    
    	}
    
    	/**
    	 * 设置需分组的原始数据集
    	 * 
    	 * @param dataSet
    	 */
    
    	public void setDataSet(List<float[]> dataSet) {
    		this.dataSet = dataSet;
    	}
    
    	/**
    	 * 获取结果分组
    	 * 
    	 * @return 结果集
    	 */
    
    	public List<List<float[]>> getCluster() {
    		return cluster;
    	}
    
    	/**
    	 * 构造函数,传入须要分成的簇数量
    	 * 
    	 * @param k
    	 *            簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度
    	 */
    	public Kmeans(int k) {
    		if (k <= 0) {
    			k = 1;
    		}
    		this.k = k;
    	}
    
    	/**
    	 * 初始化
    	 */
    	private void init() {
    		m = 0;
    		random = new Random();
    		if (dataSet == null || dataSet.size() == 0) {
    			initDataSet();
    		}
    		dataSetLength = dataSet.size();
    		if (k > dataSetLength) {
    			k = dataSetLength;
    		}
    		center = initCenters();
    		cluster = initCluster();
    		jc = new ArrayList<Float>();
    	}
    
    	/**
    	 * 假设调用者未初始化数据集,则採用内部測试数据集
    	 */
    	private void initDataSet() {
    		dataSet = new ArrayList<float[]>();
    		// 当中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
    		float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
    				{ 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
    				{ 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };
    
    		for (int i = 0; i < dataSetArray.length; i++) {
    			dataSet.add(dataSetArray[i]);
    		}
    	}
    
    	/**
    	 * 初始化中心数据链表,分成多少簇就有多少个中心点
    	 * 
    	 * @return 中心点集
    	 */
    	private ArrayList<float[]> initCenters() {
    		ArrayList<float[]> center = new ArrayList<float[]>();
    		int[] randoms = new int[k];
    		boolean flag;
    		int temp = random.nextInt(dataSetLength);
    		randoms[0] = temp;
    		for (int i = 1; i < k; i++) {
    			flag = true;
    			while (flag) {
    				temp = random.nextInt(dataSetLength);
    				int j = 0;
    
    				while (j < i) {
    					if (temp == randoms[j]) {
    						break;
    					}
    					j++;
    				}
    				if (j == i) {
    					flag = false;
    				}
    			}
    			randoms[i] = temp;
    		}
    
    		for (int i = 0; i < k; i++) {
    			center.add(dataSet.get(randoms[i]));// 生成初始化中心链表
    		}
    		return center;
    	}
    
    	/**
    	 * 初始化簇集合
    	 * 
    	 * @return 一个分为k簇的空数据的簇集合
    	 */
    	private List<List<float[]>> initCluster() {
    		List<List<float[]>> cluster = new ArrayList();
    		for (int i = 0; i < k; i++) {
    			cluster.add(new ArrayList<float[]>());
    		}
    
    		return cluster;
    	}
    
    	/**
    	 * 获取距离集合中最小距离的位置
    	 * 
    	 * @param distance
    	 *            距离数组
    	 * @return 最小距离在距离数组中的位置
    	 */
    	private int minDistance(float[] distance) {
    		float minDistance = distance[0];
    		int minLocation = 0;
    		for (int i = 1; i < distance.length; i++) {
    			if (distance[i] < minDistance) {
    				minDistance = distance[i];
    				minLocation = i;
    			} else if (distance[i] == minDistance) // 假设相等,随机返回一个位置
    			{
    				if (random.nextInt(10) < 5) {
    					minLocation = i;
    				}
    			}
    		}
    
    		return minLocation;
    	}
    
    	/**
    	 * 核心,将当前元素放到最小距离中心相关的簇中
    	 */
    	private void clusterSet() {
    		float[] distance = new float[k];
    		for (int i = 0; i < dataSetLength; i++) {
    			for (int j = 0; j < k; j++) {
    				distance[j] = CommonUtil
    						.distance(dataSet.get(i), center.get(j));
    
    			}
    			int minLocation = minDistance(distance);
    
    			cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中
    
    		}
    	}
    
    	/**
    	 * 计算误差平方和准则函数方法
    	 */
    	private void countRule() {
    		float jcF = 0;
    		for (int i = 0; i < cluster.size(); i++) {
    			for (int j = 0; j < cluster.get(i).size(); j++) {
    				jcF += CommonUtil.errorSquare(cluster.get(i).get(j), center
    						.get(i));
    
    			}
    		}
    		jc.add(jcF);
    	}
    
    	/**
    	 * 设置新的簇中心方法
    	 */
    	private void setNewCenter() {
    		for (int i = 0; i < k; i++) {
    			int n = cluster.get(i).size();
    			if (n != 0) {
    				float[] newCenter = { 0, 0 };
    				for (int j = 0; j < n; j++) {
    					newCenter[0] += cluster.get(i).get(j)[0];
    					newCenter[1] += cluster.get(i).get(j)[1];
    				}
    				// 设置一个平均值
    				newCenter[0] = newCenter[0] / n;
    				newCenter[1] = newCenter[1] / n;
    				center.set(i, newCenter);
    			}
    		}
    	}
    
    	public List<float[]> getCenter() {
    		return center;
    	}
    
    	public void setCenter(List<float[]> center) {
    		this.center = center;
    	}
    
    
    	/**
    	 * Kmeans算法核心过程方法
    	 */
    	private void kmeans() {
    		init();
    
    		// 循环分组。直到误差不变为止
    		while (true) {
    			clusterSet();
    			countRule();
    
    			if (m != 0) {
    				if (jc.get(m) - jc.get(m - 1) == 0) {
    					break;
    				}
    			}
    
    			setNewCenter();
    
    			m++;
    			cluster.clear();
    			cluster = initCluster();
    		}
    
    	}
    
    	/**
    	 * 运行算法
    	 */
    	public void execute() {
    		long startTime = System.currentTimeMillis();
    		System.out.println("kmeans begins");
    		kmeans();
    		long endTime = System.currentTimeMillis();
    		System.out.println("kmeans running time=" + (endTime - startTime)
    				+ "ms");
    		System.out.println("kmeans ends");
    		System.out.println();
    	}
    }

    分别运行两种聚类算法都使得k=5结果例如以下:

    Kmeans:

    print:cluster[0]={5.0,6.0}
    print:cluster[1]={4.0,5.0}
    print:cluster[2]={6.0,4.0}
    ===================================
    print:cluster[0]={1.0,2.0}
    print:cluster[1]={3.0,3.0}
    print:cluster[2]={3.0,4.0}
    print:cluster[3]={4.0,2.0}
    ===================================
    print:cluster[0]={7.0,8.0}
    ===================================
    print:cluster[0]={8.0,9.0}
    ===================================
    print:cluster[0]={3.0,9.0}
    print:cluster[1]={5.0,9.0}
    print:cluster[2]={1.0,9.0}
    ===================================

    BisectingKmeans:
    print:cluster0[0]={8.0,9.0}
    print:cluster0[1]={7.0,8.0}
    ===================================
    print:cluster1[0]={3.0,4.0}
    print:cluster1[1]={5.0,6.0}
    print:cluster1[2]={4.0,5.0}
    print:cluster1[3]={6.0,4.0}
    ===================================
    print:cluster2[0]={1.0,2.0}
    print:cluster2[1]={3.0,3.0}
    print:cluster2[2]={4.0,2.0}
    ===================================
    print:cluster3[0]={1.0,9.0}
    ===================================
    print:cluster4[0]={3.0,9.0}
    print:cluster4[1]={5.0,9.0}
    ===================================
    

    如上有理解问题还请指正。



    參考文献:

    http://blog.csdn.net/zouxy09/article/details/17590137

    http://wenku.baidu.com/link?url=e6sXeX_txPMnNnYy8W28mP-HSD2Lk8cQGbW-4esipqu95r-P4Ke2QPeHLhfBtoie6agplav6VtVwxlyg-jf_5byHJ_Ce93ARqA6U9rn6XKK

    《机器学习实战》

  • 相关阅读:
    Linux rabbitmq的安装和安装amqp的php插件
    php 安装xdebug扩展
    php 安装pdo_mysql 扩展
    php 安装redis扩展
    给定a、b两个文件,各存放50亿个url,每个url各占用64字节,内存限制是4G,如何找出a、b文件共同的url?
    php 实现多线程
    PHP-Fcgi下PHP的执行时间设置方法
    nginx php 安装
    PHP-fpm启动时 出现 PHP Warning: PHP Startup: Invalid library (maybe not a PHP library) 'fileinfo.so' in Unknown on line 0
    CentOS 6.7 编译PHP7 make时出现错误:undefined reference to `libiconv_close’
  • 原文地址:https://www.cnblogs.com/yangykaifa/p/6782990.html
Copyright © 2011-2022 走看看