原理入门视频:https://www.bilibili.com/video/av14601364/
实现基本功能:从txt文件中读取数据,根据给定的K值进行聚类。
Java代码:

package kmeans;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;

/**
 * Basic k-means clustering of 2-D points read from a whitespace-separated text file.
 *
 * <p>Each non-blank line of {@code simple_k-means.txt} (resolved against the working directory)
 * is parsed into one point. Every column is parsed and stored, but only the first two columns
 * take part in distance and centroid computations, and only those two are printed by
 * {@link #main(String[])}.
 *
 * <p>Initial centers are drawn at random from the data, so results can differ between runs.
 */
public class Kmeans {
    /** Input file, resolved relative to the working directory. */
    private static final String DATA_FILE = "simple_k-means.txt";

    /** Iteration stops once the sum of squared errors changes by less than this amount. */
    private static final double EPSILON = 0.001;

    /** Safety cap so that degenerate input (e.g. NaN coordinates) cannot loop forever. */
    private static final int MAX_ITERATIONS = 10000;

    private int k; // number of clusters
    private int m; // number of center updates performed so far
    private int len; // number of data points
    private final ArrayList<double[]> center; // current cluster centers
    private final ArrayList<double[]> clist; // all data points
    private final ArrayList<ArrayList<double[]>> cluster; // points grouped by cluster index
    private final ArrayList<Double> jc; // sum of squared errors recorded after each assignment pass

    /**
     * Creates a clusterer for {@code k} clusters.
     *
     * @param k requested number of clusters; values below 1 are replaced by 1
     */
    public Kmeans(int k) {
        this.k = Math.max(k, 1);
        this.m = 0;
        this.center = new ArrayList<>();
        this.clist = new ArrayList<>();
        this.cluster = new ArrayList<>();
        this.jc = new ArrayList<>();
    }

    /** Returns the number of clusters currently configured. */
    public int getK() {
        return k;
    }

    /**
     * Sets the number of clusters. Intended to be called before the centers are initialised.
     *
     * @param k requested number of clusters; values below 1 are replaced by 1, matching the
     *     constructor
     */
    public void setK(int k) {
        this.k = Math.max(k, 1);
    }

    /** Returns how many times the centers have been recomputed. */
    public int getM() {
        return m;
    }

    /** Returns the live list of cluster centers. */
    public ArrayList<double[]> getCenter() {
        return center;
    }

    /** Returns the live list of data points. */
    public ArrayList<double[]> getClist() {
        return clist;
    }

    /** Returns the live list of clusters; entry {@code i} belongs to center {@code i}. */
    public ArrayList<ArrayList<double[]>> getCluster() {
        return cluster;
    }

    /**
     * Reads the data file and appends one point per non-blank line to the point list.
     *
     * <p>Columns may be separated by any run of whitespace (spaces or tabs). I/O failures are
     * reported on standard output and leave whatever was read so far in place. Afterwards the
     * point count is refreshed and {@code k} is reduced to it if there are fewer points than
     * clusters.
     *
     * @throws NumberFormatException if a column is not a valid decimal number
     */
    public void initClist() {
        // try-with-resources guarantees the reader is closed on every path.
        try (BufferedReader reader = new BufferedReader(new FileReader(DATA_FILE))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue; // tolerate blank lines, e.g. a trailing newline
                }
                // "\\s+" is the regex for a whitespace run; a bare "\s" is not.
                String[] tokens = trimmed.split("\\s+");
                double[] point = new double[tokens.length];
                for (int i = 0; i < tokens.length; i++) {
                    point[i] = Double.parseDouble(tokens[i]);
                }
                clist.add(point);
            }
        } catch (FileNotFoundException e) {
            System.out.println(e.toString());
        } catch (IOException e) {
            System.out.println(e.toString());
        }
        len = clist.size();
        if (k > len) {
            k = len;
        }
    }

    /**
     * Picks the initial centers: up to {@code k} distinct data points chosen uniformly at random.
     *
     * <p>Sampling is done without replacement (partial Fisher-Yates shuffle of the indices) so
     * the same point is never used for two centers. Each center is a copy, so it does not alias
     * the data array. Does nothing when no points are loaded.
     */
    public void initCenter() {
        int wanted = Math.min(k, len);
        int[] order = new int[len];
        for (int i = 0; i < len; i++) {
            order[i] = i;
        }
        Random random = new Random();
        for (int i = 0; i < wanted; i++) {
            int pick = i + random.nextInt(len - i);
            int chosen = order[pick];
            order[pick] = order[i];
            order[i] = chosen;
            center.add(clist.get(chosen).clone());
        }
    }

    /** Appends {@code k} empty clusters to the cluster list. */
    public void initCluster() {
        for (int i = 0; i < k; i++) {
            cluster.add(new ArrayList<double[]>());
        }
    }

    /** Returns the squared Euclidean distance between the first two coordinates of two points. */
    private double getSumSquare(double[] element, double[] center) {
        double dx = element[0] - center[0];
        double dy = element[1] - center[1];
        return dx * dx + dy * dy;
    }

    /** Returns the index of the smallest value; the first one wins on ties, 0 if none compare. */
    private int minDistance(double[] distance) {
        double minDis = Double.POSITIVE_INFINITY;
        int minLocation = 0;
        for (int i = 0; i < distance.length; i++) {
            if (distance[i] < minDis) {
                minDis = distance[i];
                minLocation = i;
            }
        }
        return minLocation;
    }

    /** Assigns every point to the cluster of its nearest center. */
    private void clusterSet() {
        double[] distance = new double[k];
        for (int i = 0; i < len; i++) {
            double[] point = clist.get(i);
            for (int j = 0; j < k; j++) {
                // Squared distances order the same way as true distances, so no sqrt is needed.
                distance[j] = getSumSquare(point, center.get(j));
            }
            cluster.get(minDistance(distance)).add(point);
        }
    }

    /** Records the sum of squared errors of the current assignment. */
    private void countRule() {
        double sum = 0;
        for (int i = 0; i < cluster.size(); i++) {
            for (double[] point : cluster.get(i)) {
                sum += getSumSquare(point, center.get(i));
            }
        }
        jc.add(sum);
    }

    /** Moves each center to the mean of its cluster; an empty cluster keeps its old center. */
    private void setNewCenter() {
        for (int i = 0; i < k; i++) {
            ArrayList<double[]> members = cluster.get(i);
            int n = members.size();
            if (n == 0) {
                continue;
            }
            double[] newCenter = {0, 0};
            for (double[] point : members) {
                newCenter[0] += point[0];
                newCenter[1] += point[1];
            }
            newCenter[0] /= n;
            newCenter[1] /= n;
            center.set(i, newCenter);
        }
    }

    /**
     * Runs the whole algorithm: load data, pick centers, then alternate assignment and center
     * update until the error stabilises.
     *
     * <p>On exit the clusters hold the assignment made against the centers returned by
     * {@link #getCenter()}.
     */
    private void kmeans() {
        initClist();
        initCenter();
        initCluster();
        while (true) {
            clusterSet();
            countRule();
            // The error never increases between passes, so the signed difference is always
            // <= 0; compare its magnitude to detect real convergence.
            if (m != 0 && Math.abs(jc.get(m) - jc.get(m - 1)) < EPSILON) {
                break;
            }
            if (m >= MAX_ITERATIONS) {
                break;
            }
            setNewCenter();
            m++;
            cluster.clear();
            initCluster();
        }
    }

    /** Clusters the data file into two groups and prints the centers and the members. */
    public static void main(String[] args) {
        Kmeans km = new Kmeans(2); // number of clusters
        km.kmeans();
        int count = km.getM();
        ArrayList<double[]> center = km.getCenter();
        ArrayList<ArrayList<double[]>> cluster = km.getCluster();
        System.out.println("迭代次数: " + count);
        System.out.println("----------质心:------------------");
        for (double[] c : center) {
            System.out.println("[" + c[0] + "," + c[1] + "]");
        }
        System.out.println("----------聚类结果:--------------");
        for (ArrayList<double[]> members : cluster) {
            for (double[] point : members) {
                System.out.print("[" + point[0] + "," + point[1] + "] ");
            }
            System.out.println();
        }
    }
}
传统Kmeans需要给定K值,有两种初始化中心点的方法:一种是在现有的点中,随机选择K个彼此尽量远的点;一种是根据实际问题,自己初始化K个点。
可能会出现这几种情况:过早地收敛,导致局部最优;某个中心点可能会聚不到点,形成空簇。
为了解决这些问题,提出了改进的Kmeans算法,有需要的继续了解。