zoukankan      html  css  js  c++  java
  • K-means

    原理入门视频:https://www.bilibili.com/video/av14601364/

    实现基本功能:从txt文件中读取数据,根据给定的K值进行聚类。

    Java代码:

    package kmeans;
    
    import java.io.BufferedReader;
    import java.io.FileNotFoundException;
    import java.io.FileReader;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Random;
    
    /**
     * Minimal k-means clustering (Lloyd's algorithm) for two-dimensional points.
     *
     * <p>Points are read from the text file {@code simple_k-means.txt} in the working
     * directory, one point per line, with coordinates separated by whitespace. Only the
     * first two values of every line are used for clustering.
     *
     * <p>This class is not thread-safe.
     */
    public class Kmeans {

        /** Input file: one point per line, coordinates separated by whitespace. */
        private static final String DATA_FILE = "simple_k-means.txt";
        /** Iteration stops once the sum of squared errors changes by less than this. */
        private static final double EPSILON = 0.001;

        private int k; // number of clusters
        private int m; // number of center updates performed so far
        private int len; // number of loaded points
        private ArrayList<double[]> center; // current center of every cluster
        private ArrayList<double[]> clist; // all loaded points
        private ArrayList<ArrayList<double[]>> cluster; // points grouped by nearest center
        private ArrayList<Double> jc; // sum of squared errors recorded for every pass

        /**
         * Creates a clusterer.
         *
         * @param k requested number of clusters; values {@code <= 0} are replaced by 1
         */
        public Kmeans(int k) {
            if (k <= 0) {
                k = 1;
            }
            this.k = k;
            m = 0;
            center = new ArrayList<double[]>();
            clist = new ArrayList<double[]>();
            cluster = new ArrayList<ArrayList<double[]>>();
            jc = new ArrayList<Double>();
        }

        /** @return the number of clusters currently in effect */
        public int getK() {
            return k;
        }

        /**
         * Sets the number of clusters.
         *
         * @param k requested number of clusters; values {@code <= 0} are replaced by 1,
         *          matching the constructor
         */
        public void setK(int k) {
            if (k <= 0) {
                k = 1;
            }
            this.k = k;
        }

        /** @return the number of center updates performed by the last run */
        public int getM() {
            return m;
        }

        /** @return the live list of cluster centers */
        public ArrayList<double[]> getCenter() {
            return center;
        }

        /** @return the live list of all loaded points */
        public ArrayList<double[]> getClist() {
            return clist;
        }

        /** @return the live list of clusters, index-aligned with {@link #getCenter()} */
        public ArrayList<ArrayList<double[]>> getCluster() {
            return cluster;
        }

        /**
         * Loads all points from the data file, replacing any previously loaded points, and
         * lowers {@code k} to the number of points when fewer than {@code k} were read.
         *
         * <p>Blank lines are skipped. I/O failures are reported on standard output and leave
         * whatever was read so far in place.
         *
         * @throws NumberFormatException    if a value cannot be parsed as a double
         * @throws IllegalArgumentException if a non-blank line holds fewer than two values
         */
        public void initClist() {
            this.clist.clear();
            // try-with-resources closes the reader on every path.
            try (BufferedReader br = new BufferedReader(new FileReader(DATA_FILE))) {
                String str;
                while ((str = br.readLine()) != null) {
                    // Trim first: leading whitespace would otherwise yield an empty first token.
                    String line = str.trim();
                    if (line.isEmpty()) {
                        continue;
                    }
                    String[] pointStr = line.split("\\s+"); // any run of spaces or tabs
                    if (pointStr.length < 2) {
                        throw new IllegalArgumentException(
                                "Expected at least two coordinates but got: \"" + str + "\"");
                    }
                    double[] pointFlt = new double[pointStr.length];
                    for (int i = 0; i < pointStr.length; i++) {
                        pointFlt[i] = Double.parseDouble(pointStr[i]);
                    }
                    this.clist.add(pointFlt);
                }
            } catch (FileNotFoundException e) {
                System.out.println(e.toString());
            } catch (IOException e) {
                System.out.println(e.toString());
            }
            // Done outside the try so that k is clamped even when the file could not be read.
            this.len = this.clist.size();
            if (k > len) {
                k = len;
            }
        }

        /**
         * Picks {@code k} initial centers by sampling loaded points without replacement, so
         * that no point is chosen twice. {@code k} is lowered to the number of loaded points
         * if necessary.
         */
        public void initCenter() {
            this.center.clear();
            if (this.k > this.len) {
                this.k = this.len;
            }
            int[] index = new int[this.len];
            for (int i = 0; i < this.len; i++) {
                index[i] = i;
            }
            Random random = new Random();
            // Partial Fisher-Yates shuffle: the first k slots end up holding distinct indices.
            for (int i = 0; i < this.k; i++) {
                int pick = i + random.nextInt(this.len - i);
                int chosen = index[pick];
                index[pick] = index[i];
                index[i] = chosen;
                // Copy so that a center never aliases a data point.
                this.center.add(this.clist.get(chosen).clone());
            }
        }

        /** Resets the cluster list to {@code k} empty clusters. */
        public void initCluster() {
            this.cluster.clear();
            for (int i = 0; i < this.k; i++) {
                this.cluster.add(new ArrayList<double[]>());
            }
        }

        /**
         * Returns the squared Euclidean distance between a point and a center, using the
         * first two coordinates of each.
         */
        private double getSumSquare(double[] element, double[] center) {
            double x = element[0] - center[0];
            double y = element[1] - center[1];
            return x * x + y * y;
        }

        /**
         * Returns the index of the smallest value in {@code distance}; the first one wins on
         * ties, and 0 is returned for an empty array.
         */
        private int minDistance(double[] distance) {
            // Start from +infinity so that arbitrarily large distances are still compared.
            double minDis = Double.POSITIVE_INFINITY;
            int minLocation = 0;
            for (int i = 0; i < distance.length; i++) {
                if (minDis > distance[i]) {
                    minDis = distance[i];
                    minLocation = i;
                }
            }
            return minLocation;
        }

        /** Assigns every point to the cluster whose center is nearest. */
        private void clusterSet() {
            if (this.k <= 0) {
                return; // nothing to assign to
            }
            double[] distance = new double[this.k];
            for (int i = 0; i < this.len; i++) {
                double[] point = this.clist.get(i);
                for (int j = 0; j < this.k; j++) {
                    // Squared distances have the same ordering as distances, so no sqrt is needed.
                    distance[j] = getSumSquare(point, this.center.get(j));
                }
                this.cluster.get(minDistance(distance)).add(point);
            }
        }

        /** Records the sum of squared errors of the current assignment in {@code jc}. */
        private void countRule() {
            double jcf = 0;
            for (int i = 0; i < this.cluster.size(); i++) {
                double[] c = this.center.get(i);
                for (double[] point : this.cluster.get(i)) {
                    jcf += getSumSquare(point, c);
                }
            }
            jc.add(jcf);
        }

        /**
         * Moves every center to the mean of the points assigned to it. A cluster without
         * points keeps its previous center.
         */
        private void setNewCenter() {
            for (int i = 0; i < this.k; i++) {
                ArrayList<double[]> members = this.cluster.get(i);
                int n = members.size();
                if (n != 0) {
                    double[] newCenter = { 0, 0 };
                    for (double[] point : members) {
                        newCenter[0] += point[0];
                        newCenter[1] += point[1];
                    }
                    newCenter[0] = newCenter[0] / n;
                    newCenter[1] = newCenter[1] / n;
                    this.center.set(i, newCenter);
                }
            }
        }

        /**
         * Runs the full algorithm: load the points, choose initial centers, then alternate
         * assignment and center update until the sum of squared errors stops changing.
         */
        private void kmeans() {
            m = 0;
            jc.clear();
            initClist();
            initCenter();
            initCluster();
            while (true) {
                clusterSet();
                countRule();
                if (m != 0) {
                    double change = Math.abs(jc.get(m) - jc.get(m - 1));
                    // The error never increases between passes, so the signed difference is
                    // never positive; only its magnitude says whether the run has converged.
                    // Written as a negated >= so that a NaN change also ends the loop.
                    if (!(change >= EPSILON)) {
                        break;
                    }
                }
                setNewCenter();
                m++;
                initCluster();
            }
        }

        /** Clusters the points of the data file into two groups and prints the result. */
        public static void main(String[] args) {
            Kmeans km = new Kmeans(2); // number of clusters
            km.kmeans();
            int count = km.getM();
            ArrayList<double[]> center = km.getCenter();
            ArrayList<ArrayList<double[]>> cluster = km.getCluster();
            System.out.println("迭代次数: " + count);
            System.out.println("----------质心:------------------");
            for (int i = 0; i < center.size(); i++) {
                System.out.println("[" + center.get(i)[0] + "," + center.get(i)[1] + "]");
            }
            System.out.println("----------聚类结果:--------------");
            for (int i = 0; i < cluster.size(); i++) {
                for (int j = 0; j < cluster.get(i).size(); j++) {
                    double[] point = cluster.get(i).get(j);
                    System.out.print("[" + point[0] + "," + point[1] + "] ");
                }
                System.out.println();
            }
        }
    }
    java-Kmeans

     传统Kmeans需要给定K值,有两种初始化中心点的方法:一种是在现有的点中随机选择K个彼此尽量远的点,另一种是根据实际问题自行指定K个初始点。

    可能会出现这几种情况:过早的收敛,导致局部最优;某个中心点可能会聚不到点形成空簇。

    为了解决这些问题,提出了改进的Kmeans算法,有需要的继续了解。

  • 相关阅读:
    洛谷P1527 矩阵乘法——二维树状数组+整体二分
    bzoj1503 [NOI2004]郁闷的出纳员——splay
    bzoj4811 [Ynoi2017]由乃的OJ 树链剖分+位运算
    D Dandan's lunch
    C Sleepy Kaguya
    B bearBaby loves sleeping
    A AFei Loves Magic
    II play with GG
    angular的路由和监听路由的变化和用户超时的监听
    ng-disabled 指令
  • 原文地址:https://www.cnblogs.com/flyuz/p/9041240.html
Copyright © 2011-2022 走看看