zoukankan      html  css  js  c++  java
  • 算法

    Kmeans++算法

    Kmeans++算法,主要可以解决初始中心的选择问题,不可解决k的个数问题。

    Kmeans++主要思想是选择的初始聚类中心要尽量的远。

    做法:

    1.    在输入的数据点中随机选一个作为第一个聚类中心。

    2.    对于所有数据点,计算它与已有的聚类中心的最小距离D(x)

    3.    选择一个数据点作为新增的聚类中心,选择原则:D(x)较大的点被选为聚类中心的概率较大。

    4.    重复2~3步骤直到选出k个聚类中心。

    5.    运行Kmeans算法。

    package com.lfy.main;
    
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Random;
    
    /**
     * K均值聚类算法
     */
    public class Kmeans {
        private int numOfCluster;// 分成多少簇
        private int timeOfIteration;// 迭代次数
        private int dataSetLength;// 数据集元素个数,即数据集的长度
        private ArrayList<float[]> dataSet;// 数据集
        private ArrayList<float[]> center;// 质心
        private ArrayList<ArrayList<float[]>> cluster; //
        private ArrayList<Float> sumOfErrorSquare;// 误差平方和
        private Random random;
    
        /**
         * 设置需分组的原始数据集
         *
         * @param dataSet
         */
    
        public void setDataSet(ArrayList<float[]> dataSet) {
            this.dataSet = dataSet;
        }
    
        /**
         * 获取结果分组
         *
         * @return 结果集
         */
    
        public ArrayList<ArrayList<float[]>> getCluster() {
            return cluster;
        }
    
        /**
         * 构造函数,传入需要分成的簇数量
         *
         * @param numOfCluster
         *    簇数量,若numOfCluster<=0时,设置为1,若numOfCluster大于数据源的长度时,置为数据源的长度
         */
        public Kmeans(int numOfCluster) {
            if (numOfCluster <= 0) {
                numOfCluster = 1;
            }
            this.numOfCluster = numOfCluster;
        }
    
        /**
         * 初始化
         */
        private void init() {
            timeOfIteration = 0;
            random = new Random();
            //如果调用者未初始化数据集,则采用内部测试数据集
            if (dataSet == null || dataSet.size() == 0) {
                initDataSet();
            }
            dataSetLength = dataSet.size();
            //若numOfCluster大于数据源的长度时,置为数据源的长度
            if (numOfCluster > dataSetLength) {
                numOfCluster = dataSetLength;
            }
            center = initCenters();
            cluster = initCluster();
            sumOfErrorSquare = new ArrayList<Float>();
            //查看init质心的选取情况
            printDataArray(center,"initCenter");
        }
    
        /**
         * 如果调用者未初始化数据集,则采用内部测试数据集
         */
        private void initDataSet() {
            dataSet = new ArrayList<float[]>();
            // 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
            float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
                    { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
                    { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };
    
            for (int i = 0; i < dataSetArray.length; i++) {
                dataSet.add(dataSetArray[i]);
            }
        }
    
        /**
         * 随机选取k个质点
         * 初始化中心点,分成多少簇就有多少个中心点
         *
         * @return 中心点集
         */
        private ArrayList<float[]> initCenters() {
            ArrayList<float[]> center = new ArrayList<float[]>();
            int[] randoms = new int[numOfCluster];
            int temp = random.nextInt(dataSetLength);
            randoms[0] = temp;
            //----------------------
            List<Integer> list=new ArrayList<Integer>();
            list.add(temp);
            //randoms数组中存放dataSet数据集的不同的下标
            for (int i = 1; i < numOfCluster; i++) {
    //            while (true) {
    //                temp = random.nextInt(dataSetLength);
    //
    //                int j=0;
    //                for(; j<i; j++){
    //                    if(randoms[j] == temp){
    //                        break;
    //                    }
    //                }
    //                //没有与任何一个已经选定的质心重复
    //                //跳出外层循环,设定一个随机质心
    //                if (j == i) {
    //                    break;
    //                }
    //            }
                //----------------------
                ArrayList<float[]> ltemp=new ArrayList<float[]>();
                //从剩下的点中继续找质点
                for (int k = 0; k < dataSetLength; k++) {
                    //如果该点还没有被选择为质点,则计算它与已有的所有质点的最小距离
                    if(!list.contains(k)) {
                        float[] distance = new float[numOfCluster];
                        for (int j = 0; j < list.size(); j++) {
                            //某点k到已有中心点的距离
                            distance[j] = distance(dataSet.get(k), dataSet.get(list.get(j)));
                        }
                        int j = minDistance(distance);
                        float[] f={0,0};
                        f[0]=k;
                        f[1]=distance[j];
                        ltemp.add(f);
                    }
                }
                int m=maxDistance(ltemp);
                temp=(int) ltemp.get(m)[0];
                list.add(temp);
                //----------------------
                randoms[i] = temp;
            }
    
            for (int i = 0; i < numOfCluster; i++) {
                center.add(dataSet.get(randoms[i]));// 生成初始化中心点集
            }
            return center;
        }
    
        /**
         * 初始化簇集合
         *
         * @return 一个分为k簇的空数据的簇集合
         */
        private ArrayList<ArrayList<float[]>> initCluster() {
            ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
            for (int i = 0; i < numOfCluster; i++) {
                cluster.add(new ArrayList<float[]>());
            }
            return cluster;
        }
    
        /**
         * 计算两个点之间的距离
         *
         * @param element
         *            点1
         * @param center
         *            点2
         * @return 距离
         */
        private float distance(float[] element, float[] center) {
            float distance = 0.0f;
            float x = element[0] - center[0];
            float y = element[1] - center[1];
            float z = x * x + y * y;
            distance = (float) Math.sqrt(z);
    
            return distance;
        }
    
        /**
         * 获取距离集合中最小距离的位置
         *
         * @param distance
         *            距离数组
         * @return 最小距离在距离数组中的位置
         */
        private int minDistance(float[] distance) {
            float minDistance = distance[0];
            int minLocation = 0;
            for (int i = 1; i < distance.length; i++) {
                if (distance[i] <= minDistance) {
                    minDistance = distance[i];
                    minLocation = i;
                }
            }
            return minLocation;
        }
        
        /**
         * 获取距离集合中最小距离的最大的位置
         *
         * @param distance
         *            各点最小距离数组
         * @return 各点最小距离在距离数组中的位置
         */
        private int maxDistance(ArrayList<float[]> distance) {
            float[] maxDistance = distance.get(0);
            int maxLocation = 0;
            for (int i = 1; i < distance.size(); i++) {
                if (distance.get(i)[1] >= maxDistance[1]) {
                    maxDistance = distance.get(i);
                    maxLocation = i;
                }
            }
            return maxLocation;
        }
    
        /**
         * 核心,将当前元素放到最小距离的簇中
         */
        private void clusterSet() {
            float[] distance = new float[numOfCluster];
            for (int i = 0; i < dataSetLength; i++) {
                for (int j = 0; j < numOfCluster; j++) {
                    //计算数据集点与所有中心点的距离
                    distance[j] = distance(dataSet.get(i), center.get(j));
                }
                int j = minDistance(distance);
                // 核心,将当前元素放到最小距离中心相关的簇中
                cluster.get(j).add(dataSet.get(i));
            }
        }
    
        /**
         * 求族中各点到其中心点距离的平方,即误差平方
         *
         * @param element
         *            点1
         * @param center
         *            点2
         * @return 误差平方
         */
        private float errorSquare(float[] element, float[] center) {
            float x = element[0] - center[0];
            float y = element[1] - center[1];
    
            float errSquare = x * x + y * y;
    
            return errSquare;
        }
    
        /**
         * 计算一次迭代误差平方和
         */
        private void countRule() {
            float jcF = 0;
            for (int i = 0; i < cluster.size(); i++) {
                for (int j = 0; j < cluster.get(i).size(); j++) {
                    jcF += errorSquare(cluster.get(i).get(j), center.get(i));
                }
            }
            sumOfErrorSquare.add(jcF);
        }
    
        /**
         * 设置新的簇中心方法
         */
        private void setNewCenter() {
            for (int i = 0; i < numOfCluster; i++) {
                int n = cluster.get(i).size();
                if (n != 0) {
                    float[] newCenter = { 0, 0 };
                    for (int j = 0; j < n; j++) {
                        newCenter[0] += cluster.get(i).get(j)[0];
                        newCenter[1] += cluster.get(i).get(j)[1];
                    }
                    // 设置一个平均值
                    newCenter[0] = newCenter[0] / n;
                    newCenter[1] = newCenter[1] / n;
                    center.set(i, newCenter);
                }
            }
            printDataArray(center,"newCenter");
        }
    
        /**
         * 打印数据,测试用
         *
         * @param dataArray
         *            数据集
         * @param dataArrayName
         *            数据集名称
         */
        public void printDataArray(ArrayList<float[]> dataArray,
                                   String dataArrayName) {
            for (int i = 0; i < dataArray.size(); i++) {
                System.out.println("print:" + dataArrayName + "[" + i + "]={"
                        + dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
            }
            System.out.println("===================================");
        }
    
        /**
         * Kmeans算法核心过程方法
         */
        private void kmeans() {
            init();
    
            // 循环分组,直到误差不变为止
            while (true) {
                clusterSet();
    
                countRule();
    
                // 误差不变了,分组完成
                if (timeOfIteration != 0) {
                    if (sumOfErrorSquare.get(timeOfIteration) - sumOfErrorSquare.get(timeOfIteration - 1) == 0) {
                        break;
                    }
                }
                //设置各簇新的质心,继续迭代
                setNewCenter();
                timeOfIteration++;
                cluster.clear();
                cluster = initCluster();
            }
            System.out.println("note:the times of repeat:timeOfIteration="+timeOfIteration);//输出迭代次数
        }
    
        /**
         * 执行算法
         */
        public void execute() {
            long startTime = System.currentTimeMillis();
            System.out.println("kmeans begins");
            kmeans();
            long endTime = System.currentTimeMillis();
            System.out.println("kmeans running time=" + (endTime - startTime)
                    + "ms");
            System.out.println("kmeans ends");
            System.out.println();
        }
    }
    package com.lfy.main;
    
    import java.util.ArrayList;
    
    public class KmeansTest {
        public  static void main(String[] args)
        {
            //初始化一个Kmean对象,设置k值
            Kmeans k=new Kmeans(3);
            ArrayList<float[]> dataSet=new ArrayList<float[]>();
            
            dataSet.add(new float[]{3,4});
            dataSet.add(new float[]{4,4});
            dataSet.add(new float[]{3,3});
            dataSet.add(new float[]{4,3});
            //
            dataSet.add(new float[]{0,2});
            dataSet.add(new float[]{1,2});
            dataSet.add(new float[]{0,1});
            dataSet.add(new float[]{1,1});
            //
            dataSet.add(new float[]{3,1});
            dataSet.add(new float[]{3,0});
            dataSet.add(new float[]{5,0});
            dataSet.add(new float[]{4,0});
            dataSet.add(new float[]{4,1});
    
            //设置原始数据集
            k.setDataSet(dataSet);
            //执行算法
            k.execute();
            //得到聚类结果
            ArrayList<ArrayList<float[]>> cluster=k.getCluster();
            //查看结果
            for(int i=0;i<cluster.size();i++)
            {
                k.printDataArray(cluster.get(i), "cluster["+i+"]");
            }
    
        }
    }
  • 相关阅读:
    Nginx的proxy_cache缓存
    linux服务器优化
    LVS+keepalived负载均衡实战
    bash history(history命令)
    APACHE默认模块功能说明
    MySQL配置文件例子翻译
    Microsoft JET Database Engine (0x80004005) 未指定的错误的完美解决[转贴]
    entity framework 新增 修改 删除 查询
    Flash Builder 找不到所需的 Adobe Flash Player 调试器版本
    sql server 2008 远程连接
  • 原文地址:https://www.cnblogs.com/ZeroMZ/p/11827690.html
Copyright © 2011-2022 走看看