zoukankan      html  css  js  c++  java
  • kmeans聚类源代码

    代码是在weka上二次开发的,但没有使用原来的kmeans代码,只是用了它的数据类Intances,先说下与它相关的几点东西。

    一、KMeans算法简介

    输入:聚类个数k,以及包含 n个数据对象的数据库。
    输出:满足方差最小标准的k个聚类。

    处理流程:       
    1)从 n个数据对象任意选择 k 个对象作为初始聚类中心.

    2)根据每个聚类对象的均值(中心对象),计算每个对象与这些中心对象的距离;并根据最小距离重新对相应对象进行划分;
    3)重新计算每个(有变化)的聚类的均值。

    4)重复(2)(3),直到聚类不发生改变。

    划分为 k个聚类以便使得所获得的聚类满足:同一聚类中的对象相似度较高;而不同聚类中的对象相似度较小。聚类相似度是利用各聚类中对象的均值所获得一个中心对象(引力中心)来进行计算的。
        k-means
     算法的工作过程说明如下:首先从n个数据对象任意选择 k 个对象作为初始聚类中心;而对于所剩下其它对象,则根据它们与这些聚类中心的相似度(距离),分别将它们分配给与其最相似的(聚类中心所代表的)聚类;然后再计算每个所获新聚类的聚类中心(该聚类中所有对象的均值);不断重复这一过程直到标准测度函数开始收敛为止。一般都采用均方差作为标准测度函数. k个聚类具有以下特点:各聚类本身尽可能的紧凑,而各聚类之间尽可能的分开。

     

    二、weka相关介绍

    1、Instances

    Instances是数据集类,是样本或实例的集合。

    Instances有个数据头header,包含除了样本Instance之外的信息,比如attributes的信息。

    //data是Instances类型
    Instances ins = new Instances(data, 0);
    //用data的数据集信息创建一个新的Instances类,但不包含data的样本

    2、Instance

    Instance是一个接口,对应的是一个样本或实例。

    因为不同项目中数据集是不一样的,也就不能固定Instance的具体属性,所以写成接口。在自己的项目中,应该实现这个Instance

    3、Attribute

    属性类,对应Instance的一列。

    每个Attribute保存的是Instance这一列可取的值。

    a、如果是属性是数值类型,这个Attribute就用Numeric类型表示

    b、如果是String类型,这个Attribute保存了一个向量,向量的每个值是属性可以取到的String值。

    所以,Attribute不保存某个Instance的某个属性状态,而是数据集Instances的某个属性的可取范围,也就是这个属性的值域。

    所以,虽然Instance能get到某个Attribute,当不能直接从这个Attribute获取这个属性值。属性值可以直接从Instance获取。

    //获取第i个属性的值,用String表示
    instance.toString(i)

    如果属性值是数值类型,可以直接获取,不用toString(i)

    //获取第3个属性的值,从0开始计数。返回double
    instance.value(3)

    4、InstanceQuery

    我的数据是在数据库中,所以我就用InstanceQuery这个类,而且可以避免我自己去实现Instance类。

    可以参考下官方文档:http://weka.wikispaces.com/Use+WEKA+in+your+Java+code

    首先要修改DatabaseUtils.props文档,在weka.experiment包下面。

    在Eclipse IDE下没法直接修改,我用2345好压找到文件,修改了保存回去,就ok了。我连的是oracle。

    主要是修改DatabaseUtils.props的这两行:

    # JDBC driver (comma-separated list)
    jdbcDriver=oracle.jdbc.driver.OracleDriver
    
    # database URL
    jdbcURL=jdbc:oracle:thin:@192.168.2.67:1521:orcl

    然后就是在代码中写了:

        public static Instances getData() throws Exception {
            InstanceQuery query = getQuery();        
            String sql = "..............................................................";
            Instances data = query.retrieveInstances(sql);
            if (data.numInstances() <= 0)
                throw new Exception("data size is 0;");
            clusterCentroids = new Instances(data, numClusters);
            return data;
        }
        
        public static InstanceQuery getQuery() {
            InstanceQuery query = null;
            try {
                query = new InstanceQuery();
            } catch (Exception e) {
                e.printStackTrace();
            }
            query.setDatabaseURL("jdbc:oracle:thin:@192.168.2.67:1521:orcl");
            query.setUsername("...");
            query.setPassword("...");
            return query;
        }

     

    三、代码

    import java.io.BufferedWriter;
    import java.io.File;
    import java.io.FileWriter;
    import java.io.IOException;
    import java.util.Enumeration;
    import java.util.Random;import weka.core.Instance;
    import weka.core.Instances;
    import weka.core.Utils;
    import weka.experiment.InstanceQuery;
    
    public class DisKmeans {
        
        private static int numClusters = 20;
        private static int maxInteration = 600;
        private static Instances clusterCentroids;
        private static Instances[] kmeansResult;
        
        public static void main(String[] args) throws Exception {
            Instances data = getData();
            System.out.println("start");
            System.out.println(data.size());
            System.out.println(data.instance(1));
            Instance instance = data.instance(1);
            for (int i=0; i<instance.numAttributes(); i++) {
    //            System.out.println(instance.attribute(i).value(0));
                System.out.println(instance.toString(i));
            }
            System.out.println("~~~~~~~~~" + instance.value(3));
            
            getKMeansResult(data);
            printClusterResult();
            updateToDB();
        }


      /** * 获取k个中心点 * @param data * @return * @throws Exception */ public static Instances getCentroid(Instances data) throws Exception { System.out.println("-------------聚类中心-----------"); if (data.numInstances() == 0) {// 判断输入的数据文件是否为空。 throw new Exception("输入数据为空值!请检查数据集文件"); } Random random = new Random(10); int insIndex = 0; clusterCentroids = new Instances(data, numClusters); for (int i=data.numInstances()-1; i>=0 ; i--) { insIndex = random.nextInt(i);//保证i不会超过数组上线 clusterCentroids.add(data.instance(insIndex)); if (clusterCentroids.numInstances() == numClusters) { break; } } // System.out.println(clusterCentroids); printInstances(clusterCentroids); System.out.println("-----------聚类中心结束-------------"); return clusterCentroids; } /** * 根据初始中心点,对数据进行第一次聚类,结果为初始聚类结果 * @param data * @param centroids * @return */ public static Instances[] createCluster(Instances data, Instances centroids) { Instances[] newData = new Instances[centroids.numInstances()]; for (int i=0; i<newData.length; i++) { //初始化newData数组中的Instances实例。取data的header作为实例头,初始容量为0; newData[i] = new Instances(data, 0); newData[i].add(centroids.instance(i)); } for (int i=0; i<data.numInstances(); i++) { double[] tempDis = new double[centroids.numInstances()]; for (int j=0; j<centroids.numInstances(); j++) { if (!equalsInstance(data.instance(i), centroids.instance(j))) { // 重写Instance的equals方法。见EqualsInstance类 // 用欧式距离计算数据中其他实例和中心点实例的距离。 tempDis[j] = computeDistance(data.instance(i),centroids.instance(j)); } }//end for(j) int smallIndex = Utils.minIndex(tempDis); newData[smallIndex].add(data.instance(i)); }//end for(i) return newData; } /** * 得到cluster的中心点 * @param data * @return */ public static Instance meanCentroid(Instances data) { double sumValue = 0.0; double avgValue = 0.0; if (data.numInstances() <= 0) return null; if (data.numInstances() == 1) return data.firstInstance(); Instance meanIns = data.firstInstance(); for (int i=3; i<data.numAttributes(); i++) { for (int j=1; j<data.numInstances(); j++) { sumValue += data.instance(j).value(i); } avgValue = sumValue / data.numInstances(); meanIns.setValue(data.attribute(i), avgValue); sumValue = 0.0; avgValue = 0.0; } return meanIns; } private static Instances updateCentroids(Instances centroids, Instance instance, int index) { centroids.add(instance); int temp = centroids.numInstances() - 1; centroids.swap(index, temp); centroids.delete(temp); return centroids; } public static Instances[] getKMeansResult(Instances data) throws Exception { System.out.println("聚类开始"); clusterCentroids = getCentroid(data);//k个中心点 kmeansResult = createCluster(data, clusterCentroids);//第一次聚类 for (int i=0; i<maxInteration; i++) { Instances newCentroids = new Instances(clusterCentroids); for (int j=0; j<kmeansResult.length; j++) { Instances cluster = kmeansResult[j]; Instance tmpIns = meanCentroid(cluster); newCentroids = updateCentroids(newCentroids, tmpIns, j); } System.out.println("迭代......" + (i + 1)); if (equalsInstance(newCentroids, clusterCentroids)) { System.out.println("中心点集合不再变化!迭代次数" + (i + 1)); break; } clusterCentroids = newCentroids; kmeansResult = createCluster(data, clusterCentroids);//根据中心点,重新聚类 if (i==maxInteration-2 || i==maxInteration-1) { printInstances(clusterCentroids); } } System.out.println("~~~~~~~~~聚类结束~~~~~~~~~"); printInstances(clusterCentroids); return kmeansResult; } public static void printClusterResult() throws IOException { // System.out.println("聚类结果:"); int i=0; BufferedWriter bw = new BufferedWriter(new FileWriter(new File("result.txt"))); for (Instances ins : kmeansResult) { // System.out.println("cluster " + i++ + ":"); bw.append("cluster " + i++ + ": "); @SuppressWarnings("rawtypes") Enumeration en = ins.enumerateInstances(); while (en.hasMoreElements()) { // System.out.println(en.nextElement()); bw.append(en.nextElement().toString()); bw.append(" "); } // System.out.println(); bw.append(" "); } bw.flush(); bw.close(); } public static Instances getData() throws Exception { InstanceQuery query = getQuery(); String sql = "..................................................................."; Instances data = query.retrieveInstances(sql); if (data.numInstances() <= 0) throw new Exception("data size is 0;"); clusterCentroids = new Instances(data, numClusters); return data; } public static InstanceQuery getQuery() { InstanceQuery query = null; try { query = new InstanceQuery(); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } query.setDatabaseURL("jdbc:oracle:thin:@192.168.2.67:1521:orcl"); query.setUsername("..."); query.setPassword("..."); return query; } private static double computeDistance(Instance in1, Instance in2) { double dis = 0.0; // if ( in1.toString(0).equals(in2.toString(0)) ) { // dis += 1; // } // if ( in1.toString(2).equals(in2.toString(2)) ) { // dis += 5; // } for (int k=3; k<in1.numAttributes(); k++) { double dicrim = in1.value(k) - in2.value(k); dis += dicrim*dicrim; } return dis; } private static boolean equalsInstance(Instance in1, Instance in2) { return in1.equals(in2); } private static boolean equalsInstance(Instances ins1, Instances ins2) { boolean flag = true; if (ins1.numInstances() != ins2.numInstances()) return false; for (int i=0; i<ins1.numInstances(); i++) { if (computeDistance(ins1.instance(i), ins2.instance(i)) >= 0.0001) { flag = false; break; } } return flag; } public static void printInstances(Instances instances) { for (int i=0; i<instances.numInstances(); i++) { System.out.println(instances.instance(i)); } } }

     

    四、其他

    1、离散变量:数据中有个有用的离散变量,不方便直接放到kmeans算法中。但一定要用,也是可以的。

    2、DatabaseUtils.props文件,在jar包中不方便直接修改,而我用解压工具修改在压缩回去,虽然可行,总觉得不太好。不知道有什么更好的方法?

  • 相关阅读:
    Ros学习——Cmakelists.txt文件解读
    Ros学习——Movebase源码解读
    C++——STL之vector, list, deque容器对比与常用函数
    Ros学习——移动机器人Ros导航详解及源码解析
    C++——多线程
    C++——STL容器
    PHP对图片按照一定比例缩放并生成图片文件
    PHP二维数组排序
    PHP裁剪图片并上传完整demo
    [PHP] php实现文件下载
  • 原文地址:https://www.cnblogs.com/549294286/p/3461994.html
Copyright © 2011-2022 走看看