zoukankan      html  css  js  c++  java
  • K-Means 算法(Java)

    K-Means 算法的原理见我的上一篇文章。这里介绍 K-Means 的 Java 实现方法,参考了 Python 版本的实现。

    一、数据点的实现

    package com.meachine.learning.kmeans;
    
    import java.util.ArrayList;
    
    /**
     * A data point with an arbitrary number of numeric coordinates.
     *
     * <p>Besides its coordinates, a point carries the bookkeeping used by K-Means: the id of the
     * cluster it is currently assigned to and its distance to that cluster's centre.
     */
    public class Point {
        // Counter used to hand out sequential ids (plain static int, not thread-safe).
        private static int num;
        private int id;
        // Number of coordinates; always kept equal to values.size().
        private int dimensioNum;
        private ArrayList<Double> values;
        // -1 means "not assigned to any cluster yet".
        private int clusterId = -1;
        // Positive infinity so that every finite distance compares smaller than the initial value.
        private double minDist = Double.POSITIVE_INFINITY;

        /** Creates an empty (zero-dimensional) point with the next sequential id. */
        public Point() {
            id = ++num;
            values = new ArrayList<>();
        }

        /**
         * Appends one coordinate, increasing the dimension by one.
         *
         * @param e the coordinate value to append
         */
        public void add(double e) {
            values.add(e);
            dimensioNum++;
        }

        public int getId() {
            return id;
        }

        public void setId(int id) {
            this.id = id;
        }

        /** Returns the number of coordinates of this point. */
        public int getDimensioNum() {
            return dimensioNum;
        }

        /**
         * Returns the coordinate list. This is the internal list, not a copy; mutating it directly
         * bypasses the dimension bookkeeping done by {@link #add(double)}.
         */
        public ArrayList<Double> getValues() {
            return values;
        }

        /**
         * Replaces all coordinates with a copy of {@code values} and updates the dimension.
         *
         * @param values the new coordinates
         * @throws NullPointerException if {@code values} is null
         */
        public void setValues(ArrayList<Double> values) {
            if (values == null) {
                throw new NullPointerException("values");
            }
            // Copy so later changes to the caller's list cannot desynchronise dimensioNum.
            this.values = new ArrayList<>(values);
            this.dimensioNum = this.values.size();
        }

        /** Returns the id of the assigned cluster, or -1 if unassigned. */
        public int getClusterId() {
            return clusterId;
        }

        public void setClusterId(int clusterId) {
            this.clusterId = clusterId;
        }

        /** Returns the last recorded distance to the assigned centre (positive infinity initially). */
        public double getMinDist() {
            return minDist;
        }

        public void setMinDist(double minDist) {
            this.minDist = minDist;
        }
    }

    二、数据簇的实现

    package com.meachine.learning.kmeans;
    
    import lombok.EqualsAndHashCode;
    import lombok.Getter;
    import lombok.Setter;
    import lombok.ToString;
    
    /**
     * A cluster: its id, a point counter, and its current centre point.
     */
    public class Cluster {
        // Cluster id.
        private int clusterId;
        // Number of points belonging to this cluster. This class never updates it itself;
        // it stays 0 unless a caller sets it.
        private int numOfPoints;
        // Current centre of the cluster; null until set.
        private Point center;

        /**
         * Creates a cluster with no centre yet.
         *
         * @param id the cluster id
         */
        public Cluster(int id) {
            this.clusterId = id;
            numOfPoints = 0;
        }

        /**
         * Creates a cluster with an initial centre.
         *
         * @param id     the cluster id
         * @param center the initial centre point
         */
        public Cluster(int id, Point center) {
            this(id);
            this.center = center;
        }

        public int getClusterId() {
            return clusterId;
        }

        public void setClusterId(int clusterId) {
            this.clusterId = clusterId;
        }

        public int getNumOfPoints() {
            return numOfPoints;
        }

        public void setNumOfPoints(int numOfPoints) {
            this.numOfPoints = numOfPoints;
        }

        /** Returns the current centre, or null if none has been set. */
        public Point getCenter() {
            return center;
        }

        public void setCenter(Point center) {
            this.center = center;
        }
    }

    三、计算数据点距离

    package com.meachine.learning.kmeans;
    
    import java.util.List;
    
    /**
     * Strategy for measuring the distance between two vectors.
     *
     * @param <T> element type of the vectors
     */
    public interface IDistance<T> {

        /**
         * Returns the distance between {@code first} and {@code second}.
         *
         * @param first  the first vector
         * @param second the second vector
         * @return the distance between the two vectors
         */
        double getDis(List<T> first, List<T> second);
    }
    

      

    package com.meachine.learning.kmeans;
    
    import java.util.List;
    
    /**
     * Euclidean (L2) distance: the square root of the sum of squared coordinate differences.
     *
     * @param <T> numeric element type, converted with {@link Number#doubleValue()}
     */
    public class OujilidDistance<T extends Number> implements IDistance<T> {

        /**
         * Computes the Euclidean distance between {@code a} and {@code b}.
         *
         * @param a the first vector
         * @param b the second vector
         * @return the Euclidean distance between the two vectors
         * @throws IllegalArgumentException if the vectors have different lengths
         */
        public double getDis(List<T> a, List<T> b) {
            final int len = a.size();
            if (len != b.size()) {
                throw new IllegalArgumentException("Size not compatible!");
            }
            double sumOfSquares = 0;
            int idx = 0;
            while (idx < len) {
                double delta = a.get(idx).doubleValue() - b.get(idx).doubleValue();
                sumOfSquares += Math.pow(delta, 2);
                idx++;
            }
            return Math.sqrt(sumOfSquares);
        }
    }

    四、K-Means算法

      

    package com.meachine.learning.kmeans;
    
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.io.UncheckedIOException;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Random;
    
    /**
     * K-Means clustering over points loaded from a whitespace-separated text file.
     *
     * <p>Usage: construct with the number of clusters, call {@link #init()} to load the data and
     * choose the initial centres, then call {@link #clustering()}.
     *
     * @author Cang
     *
     */
    public class KMeans {
        // Number of clusters (k).
        private final int k;
        // Dimensionality of every data point, i.e. the number of columns per data line.
        private int dimensioNum;
        // Upper bound on the number of assign/update iterations.
        private int maxItrNum = 100;
        private IDistance<Double> distance;
        private List<Point> points;
        private final List<Cluster> clusters = new ArrayList<Cluster>();
        private String dataFileName = "D:/testSet.txt";

        /**
         * Creates a K-Means run that reads its data from the default file {@code D:/testSet.txt}.
         *
         * @param k number of clusters; must be positive
         * @throws IllegalArgumentException if {@code k <= 0}
         */
        public KMeans(int k) {
            if (k <= 0) {
                throw new IllegalArgumentException("k must be positive: " + k);
            }
            this.k = k;
        }

        /**
         * Creates a K-Means run that reads its data from {@code dataFileName}.
         *
         * @param k            number of clusters; must be positive
         * @param dataFileName path of the data file (one point per line, whitespace-separated)
         * @throws IllegalArgumentException if {@code k <= 0}
         * @throws NullPointerException     if {@code dataFileName} is null
         */
        public KMeans(int k, String dataFileName) {
            this(k);
            if (dataFileName == null) {
                throw new NullPointerException("dataFileName");
            }
            this.dataFileName = dataFileName;
        }

        /**
         * Loads the data set, sets up the Euclidean distance and picks the initial centres.
         */
        public void init() {
            points = loadDataSet(dataFileName);
            distance = new OujilidDistance<Double>();
            initCluster();
        }

        /**
         * Loads the data set: one point per non-blank line, coordinates separated by whitespace.
         *
         * @param fileName path of the data file
         * @return the parsed points, in file order
         * @throws UncheckedIOException     if the file cannot be read
         * @throws IllegalArgumentException if lines have differing numbers of columns
         * @throws NumberFormatException    if a field is not a number
         */
        private List<Point> loadDataSet(String fileName) {
            List<Point> points = new ArrayList<>();
            File file = new File(fileName);
            // try-with-resources closes the reader even if parsing fails part-way through.
            try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
                String line;
                int lineNo = 0;
                int expectedDims = -1;
                while ((line = reader.readLine()) != null) {
                    lineNo++;
                    String trimmed = line.trim();
                    if (trimmed.isEmpty()) {
                        continue; // tolerate blank lines, e.g. a trailing newline at EOF
                    }
                    String[] fields = trimmed.split("\\s+");
                    if (expectedDims == -1) {
                        expectedDims = fields.length;
                    } else if (fields.length != expectedDims) {
                        // changeCentroids indexes every coordinate, so ragged rows must be rejected.
                        throw new IllegalArgumentException("Line " + lineNo + " of " + fileName
                                + " has " + fields.length + " columns, expected " + expectedDims);
                    }
                    Point point = new Point();
                    for (String field : fields) {
                        point.add(Double.parseDouble(field));
                    }
                    points.add(point);
                }
                dimensioNum = Math.max(expectedDims, 0);
            } catch (IOException e) {
                throw new UncheckedIOException("Failed to read data set: " + fileName, e);
            }
            return points;
        }

        /**
         * Picks k distinct data points at random as the initial cluster centres (ids 1..k).
         *
         * @throws IllegalStateException if fewer than k points were loaded
         */
        private void initCluster() {
            if (points.size() < k) {
                throw new IllegalStateException(
                        "Need at least k=" + k + " points, but only " + points.size() + " were loaded");
            }
            clusters.clear(); // re-running init() must not accumulate extra clusters
            Random ran = new Random();
            int[] indices = new int[points.size()];
            for (int i = 0; i < indices.length; i++) {
                indices[i] = i;
            }
            // Partial Fisher-Yates shuffle: slots 0..k-1 end up holding k distinct random indices.
            for (int slot = 0; slot < k; slot++) {
                int pick = slot + ran.nextInt(indices.length - slot);
                int tmp = indices[slot];
                indices[slot] = indices[pick];
                indices[pick] = tmp;
                Cluster c = new Cluster(slot + 1);
                c.setCenter(points.get(indices[slot]));
                clusters.add(c);
            }
        }

        /**
         * Runs K-Means: repeatedly assigns every point to its nearest centre and moves each centre
         * to the mean of its members, until no assignment changes or {@code maxItrNum} iterations
         * have run. Prints the centre coordinates at every iteration.
         *
         * @throws IllegalStateException if {@link #init()} has not been called
         */
        public void clustering() {
            if (points == null || clusters.isEmpty()) {
                throw new IllegalStateException("init() must be called before clustering()");
            }
            boolean finished = false;
            int count = 0;
            while (!finished) {
                // Assign each point to its nearest centre.
                finished = !assignPointsToNearestCluster();
                System.out.println("Cluster center info:");
                for (Cluster cluster : clusters) {
                    System.out.println(cluster.getCenter().getValues());
                }
                // Move each centre to the mean of its members.
                changeCentroids();
                // Stop after maxItrNum iterations at most.
                if (++count >= maxItrNum) {
                    finished = true;
                }
            }
        }

        /**
         * Assigns every point to its nearest cluster centre. The nearest centre is recomputed from
         * scratch on every call, because centres move between iterations.
         *
         * @return true if at least one point changed cluster
         */
        private boolean assignPointsToNearestCluster() {
            boolean changed = false;
            for (Point point : points) {
                int bestId = point.getClusterId();
                double bestDist = Double.POSITIVE_INFINITY;
                for (Cluster cluster : clusters) {
                    double d = distance.getDis(cluster.getCenter().getValues(), point.getValues());
                    if (d < bestDist) {
                        bestDist = d;
                        bestId = cluster.getClusterId();
                    }
                }
                if (bestId != point.getClusterId()) {
                    changed = true;
                    point.setClusterId(bestId);
                }
                point.setMinDist(bestDist);
            }
            return changed;
        }

        /**
         * Moves every non-empty cluster's centre to the coordinate-wise mean of its member points.
         * A cluster with no members keeps its previous centre, which avoids a division by zero.
         */
        private void changeCentroids() {
            for (Cluster cluster : clusters) {
                double[] sums = new double[dimensioNum];
                int members = 0;
                for (Point point : points) {
                    if (point.getClusterId() == cluster.getClusterId()) {
                        members++;
                        for (int i = 0; i < dimensioNum; i++) {
                            sums[i] += point.getValues().get(i);
                        }
                    }
                }
                if (members == 0) {
                    continue;
                }
                Point newCenterPoint = new Point();
                for (int i = 0; i < dimensioNum; i++) {
                    newCenterPoint.add(sums[i] / members);
                }
                newCenterPoint.setClusterId(cluster.getClusterId());
                cluster.setCenter(newCenterPoint);
            }
        }

        /**
         * Runs K-Means with k = 4. Reads the data file given as the first argument, or the default
         * file when no argument is given.
         */
        public static void main(String[] args) {
            KMeans kmeans = args.length > 0 ? new KMeans(4, args[0]) : new KMeans(4);
            kmeans.init();
            kmeans.clustering();
        }
    }
    

      

  • 相关阅读:
    [Leetcode] Merge Intervals
    [Leetcode] Sort Colors
    junit
    DBUnit的使用
    xml简介---来自百度百科
    今天开始深入学习XML
    Java 用Myeclipse部署项目基础坏境搭建
    properties配置文件读取方法
    Java web做服务器之间的通信方法
    Java Socket简单的客服端及其服务器端
  • 原文地址:https://www.cnblogs.com/codingexperience/p/5040942.html
Copyright © 2011-2022 走看看