  • Repost: Spectral Clustering

    Spectral Clustering

     

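    Broadly speaking, any algorithm that relies on SVD or eigendecomposition can be called a spectral algorithm. As an aside, every matrix (including a non-square one) has a singular value decomposition, but not every matrix has an eigendecomposition. For a symmetric positive definite (or semidefinite) matrix the two coincide: the singular values are the eigenvalues, and the singular vectors are the eigenvectors.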
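As a small worked check of that last statement, consider two illustrative 2×2 matrices (chosen here for convenience, not taken from the original post). The singular values of a matrix are the square roots of the eigenvalues of A^T A. For the positive definite matrix A they match the eigenvalues; for the symmetric but indefinite matrix B they only match in absolute value.

```latex
A = \begin{pmatrix} 2 & 1 \\ 1 & 2 \end{pmatrix}, \quad
\det(A - \lambda I) = (2-\lambda)^2 - 1 = 0
\;\Rightarrow\; \lambda \in \{3, 1\}

A^{\top} A = \begin{pmatrix} 5 & 4 \\ 4 & 5 \end{pmatrix}
\text{ has eigenvalues } \{9, 1\}
\;\Rightarrow\; \sigma \in \{3, 1\}

B = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}: \quad
\lambda \in \{1, -1\}, \quad
B^{\top} B = I \;\Rightarrow\; \sigma \in \{1, 1\}
```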

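    Traditional clustering algorithms such as K-Means and EM implicitly assume convex, roughly spherical clusters. When the clusters are not convex, these algorithms tend to get stuck in a local optimum, and the final result depends heavily on the choice of initial parameters. Spectral clustering, by contrast, can cluster sample spaces of arbitrary shape, and the relaxed graph-partitioning problem it solves through eigendecomposition has a globally optimal solution.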

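    Spectral clustering closely resembles CHAMELEON clustering: both place the pairwise similarities of the samples in a weighted undirected graph and cluster by graph partitioning. The difference is that exact graph partitioning is computationally expensive, so spectral clustering solves an eigenvalue problem instead and then runs K-Means on the matrix formed by the eigenvectors of the few smallest eigenvalues.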

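    Simply put, spectral clustering has 3 steps:

    1. Construct an N×N weight matrix W, where Wij is the similarity between sample i and sample j; W is clearly symmetric. There are many ways to compute similarity, for example from the Euclidean distance, the city-block (Manhattan) distance, the angle between vectors, or the Pearson correlation coefficient. Not every pair of points needs an edge in the graph. We want the weighted graph to be fairly sparse, and there are 2 ways to get that: treat any weight below a threshold as 0, or use K-nearest neighbors, connecting each point only to its k nearest points (this is what the first phase of CHAMELEON does). Next construct a diagonal matrix D, where Dii is the sum of the elements in column i of W. Finally construct the matrix L=D-W. It can be shown that L is symmetric and positive semidefinite.
    2. Compute the eigenvectors corresponding to the K smallest eigenvalues of L (this is where the singular value decomposition comes in). Stack the K eigenvectors as columns to form an N×K matrix M.
    3. Treat each row of M as a new sample point and run K-Means on these N new points.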
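Step 1 can be sketched in a few lines of Java. This is a minimal illustration under stated assumptions, not the author's implementation: the class name `LaplacianSketch`, its method names, and the four sample points are hypothetical. It uses the thresholding approach with similarity 1/(1+distance), leaves the diagonal of W at 0 (a self-similarity cancels out in D-W anyway), and evaluates x^T L x to illustrate that the quadratic form of a graph Laplacian is non-negative.

```java
public class LaplacianSketch {

    // Builds L = D - W for 2-D points, with similarity 1 / (1 + Euclidean distance).
    // Similarities that do not exceed the threshold are dropped (treated as 0).
    static double[][] laplacian(double[][] points, double threshold) {
        int n = points.length;
        double[][] w = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                double dx = points[i][0] - points[j][0];
                double dy = points[i][1] - points[j][1];
                double sim = 1.0 / (1.0 + Math.sqrt(dx * dx + dy * dy));
                if (sim > threshold) {
                    w[i][j] = sim;
                    w[j][i] = sim;
                }
            }
        }
        double[][] l = new double[n][n];
        for (int i = 0; i < n; i++) {
            double degree = 0.0; // D_ii: sum of row i of W (same as the column sum, W is symmetric)
            for (int j = 0; j < n; j++) {
                degree += w[i][j];
                l[i][j] = 0.0 - w[i][j];
            }
            l[i][i] = degree; // the diagonal of W is 0 here, so L_ii = D_ii
        }
        return l;
    }

    // Evaluates the quadratic form x^T L x.
    static double quadraticForm(double[][] l, double[] x) {
        double result = 0.0;
        for (int i = 0; i < l.length; i++)
            for (int j = 0; j < l.length; j++)
                result += x[i] * l[i][j] * x[j];
        return result;
    }

    public static void main(String[] args) {
        // Two well-separated pairs of points (made-up data)
        double[][] points = {{0, 0}, {0, 1}, {5, 5}, {5, 6}};
        double[][] l = laplacian(points, 0.2);
        for (double[] row : l)
            System.out.println(java.util.Arrays.toString(row));
        System.out.println("x^T L x = " + quadraticForm(l, new double[] {1, -1, 2, 0}));
    }
}
```

With a threshold of 0.2 only the two close pairs stay connected, so L is block diagonal and every row sums to zero.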

    Read the sample points from a file and compute the matrix L:

    #include<stdio.h>
    #include<stdlib.h>
    #include<assert.h>
    #include<math.h>
    #include<string.h>
    #include"matrix.h"
    #include"svd.h"
     
    #define N 19        //number of sample points
    #define K 4         //the K in K-Means
    #define T 0.1       //similarity threshold between sample points
     
    double sample[N][2];    //coordinates of all sample points (2-D)
     
    void readSample(const char *filename){
        FILE *fp;
        if((fp=fopen(filename,"r"))==NULL){
            perror("fopen");
            exit(EXIT_FAILURE);
        }
        char buf[50]={0};
        int i=0;
        //each line holds one point as "x y", separated by a space
        while(i<N && fgets(buf,sizeof(buf),fp)!=NULL){
            char *w=strtok(buf," ");
            if(w==NULL)
                continue;       //skip empty lines
            double x=atof(w);
            w=strtok(NULL," ");
            if(w==NULL)
                continue;       //skip lines with a missing coordinate
            double y=atof(w);
            sample[i][0]=x;
            sample[i][1]=y;
            i++;
        }
        assert(i==N);
        fclose(fp);
    }
     
    double** getSimMatrix(){
        //allocate the N×N matrix
        double **matrix=getMatrix(N,N);
        //compute the pairwise similarities to get the matrix W
        int i,j;
        for(i=0;i<N;i++){
            matrix[i][i]=1;
            for(j=i+1;j<N;j++){
                double dist=sqrt(pow(sample[i][0]-sample[j][0],2)+pow(sample[i][1]-sample[j][1],2));
                double sim=1.0/(1+dist);
                if(sim<=T)
                    sim=0;      //drop weak edges to keep the graph sparse
                matrix[j][i]=sim;
                matrix[i][j]=sim;
            }
        }
        //compute L=D-W in place, one column at a time
        for(j=0;j<N;j++){
            double sum=0;       //D[j][j], the sum of column j of W
            for(i=0;i<N;i++){
                sum+=matrix[i][j];
                if(i!=j)
                    matrix[i][j]=0-matrix[i][j];
            }
            //L[j][j]=D[j][j]-W[j][j]
            matrix[j][j]=sum-matrix[j][j];
        }
        return matrix;
    }
     
    int main(){
        char *file="/home/orisun/data";
        readSample(file);
        double **L=getSimMatrix();
        printMatrix(L,N,N);
         
        double **M=singleVector(L,N,N,5);
        printMatrix(M,N,5);
         
        freeMatrix(L,N);
     
        return 0;
    }

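    L is symmetric and positive semidefinite, so its singular value decomposition directly yields its eigenvectors. Keep in mind that SVD routines usually return the singular values in descending order, whereas step 2 needs the vectors that belong to the smallest ones.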
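For readers without the `svd.h` routine used above, step 2 can also be sketched with the cyclic Jacobi eigenvalue method, which works directly on symmetric matrices. This is a swapped-in technique, not what the original post uses, and the names `JacobiEigenSketch`, `jacobi`, and `smallestEigenvectors` are hypothetical. The example matrix is the Laplacian of a graph made of two disconnected edges of weight 1.

```java
import java.util.Arrays;

public class JacobiEigenSketch {

    // Cyclic Jacobi method for a symmetric matrix. Returns the eigenvalues (unsorted)
    // and fills v so that column i of v is the eigenvector of eigenvalue i.
    static double[] jacobi(double[][] input, double[][] v) {
        int n = input.length;
        double[][] a = new double[n][];
        for (int i = 0; i < n; i++) {
            a[i] = input[i].clone();
            Arrays.fill(v[i], 0.0);
            v[i][i] = 1.0;
        }
        for (int sweep = 0; sweep < 100; sweep++) {
            double off = 0.0;
            for (int p = 0; p < n; p++)
                for (int q = p + 1; q < n; q++)
                    off += a[p][q] * a[p][q];
            if (off < 1e-20) break; // numerically diagonal
            for (int p = 0; p < n - 1; p++) {
                for (int q = p + 1; q < n; q++) {
                    if (a[p][q] == 0.0) continue;
                    double theta = (a[q][q] - a[p][p]) / (2.0 * a[p][q]);
                    double t = (theta >= 0 ? 1.0 : -1.0)
                            / (Math.abs(theta) + Math.sqrt(theta * theta + 1.0));
                    double c = 1.0 / Math.sqrt(t * t + 1.0);
                    double s = t * c;
                    for (int r = 0; r < n; r++) { // A <- A * P
                        double arp = a[r][p], arq = a[r][q];
                        a[r][p] = c * arp - s * arq;
                        a[r][q] = s * arp + c * arq;
                    }
                    for (int r = 0; r < n; r++) { // A <- P^T * A
                        double apr = a[p][r], aqr = a[q][r];
                        a[p][r] = c * apr - s * aqr;
                        a[q][r] = s * apr + c * aqr;
                    }
                    for (int r = 0; r < n; r++) { // V <- V * P
                        double vrp = v[r][p], vrq = v[r][q];
                        v[r][p] = c * vrp - s * vrq;
                        v[r][q] = s * vrp + c * vrq;
                    }
                }
            }
        }
        double[] eig = new double[n];
        for (int i = 0; i < n; i++) eig[i] = a[i][i];
        return eig;
    }

    // Returns the N×k matrix M whose columns are eigenvectors of the k smallest eigenvalues.
    static double[][] smallestEigenvectors(double[][] l, int k) {
        int n = l.length;
        double[][] v = new double[n][n];
        final double[] eig = jacobi(l, v);
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Arrays.sort(order, (x, y) -> Double.compare(eig[x], eig[y]));
        double[][] m = new double[n][k];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < k; j++)
                m[i][j] = v[i][order[j]];
        return m;
    }

    public static void main(String[] args) {
        // Laplacian of a graph with two disconnected edges, each of weight 1
        double[][] l = {
            { 1, -1,  0,  0},
            {-1,  1,  0,  0},
            { 0,  0,  1, -1},
            { 0,  0, -1,  1}
        };
        for (double[] row : smallestEigenvectors(l, 2))
            System.out.println(Arrays.toString(row));
    }
}
```

The rows of M that belong to the same connected component come out identical, which is why running K-Means on the rows of M in step 3 separates the components easily.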

    Finally, the Java code that runs KMeans:

    package ai;
     
    public class Global {
        //Euclidean distance between two vectors
        public static double calEuraDist(double[] arr1,double[] arr2,int len){
            double result=0.0;
            for(int i=0;i<len;i++){
                result+=Math.pow(arr1[i]-arr2[i],2.0);
            }
            return Math.sqrt(result);
        }
    }
    package ai;
     
    public class DataObject {
     
        String docname;
        double[] vector;
        int cid;   
        boolean visited;
         
        public DataObject(int len){
            vector=new double[len];
        }
     
        public String getName() {
            return docname;
        }
     
        public void setName(String docname) {
            this.docname = docname;
        }
     
        public double[] getVector() {
            return vector;
        }
     
        public void setVector(double[] vector) {
            this.vector = vector;
        }
     
        public int getCid() {
            return cid;
        }
     
        public void setCid(int cid) {
            this.cid = cid;
        }
     
        public boolean isVisited() {
            return visited;
        }
     
        public void setVisited(boolean visited) {
            this.visited = visited;
        }
     
    }
    package ai;
     
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Iterator;
    public class DataSource {
     
        ArrayList<DataObject> objects;
        int row;
        int col;
     
        public void readMatrix(File dataFile) {
            // First line: "<rows> <cols>"; each following line: one sample vector
            try (BufferedReader br = new BufferedReader(new FileReader(dataFile))) {
                String line = br.readLine();
                String[] words = line.trim().split("\\s+");
                row = Integer.parseInt(words[0]);
                col = Integer.parseInt(words[1]);
                objects = new ArrayList<DataObject>(row);
                for (int i = 0; i < row; i++) {
                    DataObject object = new DataObject(col);
                    line = br.readLine();
                    words = line.trim().split("\\s+");
                    for (int j = 0; j < col; j++) {
                        object.getVector()[j] = Double.parseDouble(words[j]);
                    }
                    objects.add(object);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
     
        public void readRLabel(File file) {
            // One label per line, in the same order as the rows read by readMatrix
            try (BufferedReader br = new BufferedReader(new FileReader(file))) {
                for (int i = 0; i < row; i++) {
                    String line = br.readLine();
                    objects.get(i).setName(line.trim());
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
     
        public void printResult(ArrayList<DataObject> objects, int n) {
            // K-Means cluster ids run from 0 to n-1 (DBScan ids would start at 1)
            for (int i = 0; i < n; i++) {
                System.out.println("=============Members of cluster " + i + ":===========================");
                for (DataObject object : objects) {
                    if (object.getCid() == i) {
                        System.out.println(object.getName());
                    }
                }
            }
        }
    }
    package ai;
     
    import java.io.File;
    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.Random;
      
    public class KMeans {
      
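        int k; // number of clusters
        double mu; // stopping threshold: iteration ends once every new centroid has moved less than mu from the old one
        double[][] center; // centroids from the previous iteration
        int repeat; // number of repeated runs
        double[] crita; // score of each run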
      
        public KMeans(int k, double mu, int repeat, int len) {
            this.k = k;
            this.mu = mu;
            this.repeat = repeat;
            center = new double[k][];
            for (int i = 0; i < k; i++)
                center[i] = new double[len];
            crita = new double[repeat];
        }
      
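        // Initialize the k centroids (each a len-dimensional vector): assign every point to a random cluster and take the cluster means
        public void initCenter(int len, ArrayList<DataObject> objects) {
            Random random = new Random(System.currentTimeMillis());
            int[] count = new int[k]; // number of points in each cluster
            // reset the centroids so that repeated runs do not accumulate old values
            for (int i = 0; i < k; i++)
                java.util.Arrays.fill(center[i], 0.0);
            Iterator<DataObject> iter = objects.iterator();
            while (iter.hasNext()) {
                DataObject object = iter.next();
                int id = random.nextInt(k);
                count[id]++;
                for (int i = 0; i < len; i++)
                    center[id][i] += object.getVector()[i];
            }
            for (int i = 0; i < k; i++) {
                if (count[i] == 0)
                    continue; // an empty cluster keeps a zero centroid instead of NaN
                for (int j = 0; j < len; j++) {
                    center[i][j] /= count[i];
                }
            }
        }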
      
        // Assign every point in the data set to its nearest centroid
        public void classify(ArrayList<DataObject> objects) {
            Iterator<DataObject> iter = objects.iterator();
            while (iter.hasNext()) {
                DataObject object = iter.next();
                double[] vector = object.getVector();
                int len = vector.length;
                int index = 0;
                double neardist = Double.MAX_VALUE;
                for (int i = 0; i < k; i++) {
                    double dist = Global.calEuraDist(vector, center[i], len); // Euclidean distance
                    if (dist < neardist) {
                        neardist = dist;
                        index = i;
                    }
                }
                object.setCid(index);
            }
        }
      
        // Recompute the centroid of each cluster and test the stopping condition: returns true if it is met, otherwise updates the centroids and returns false. len is the dimensionality of the data
        public boolean calNewCenter(ArrayList<DataObject> objects, int len) {
            boolean end = true;
            int[] count = new int[k]; // number of points in each cluster
            double[][] sum = new double[k][];
            for (int i = 0; i < k; i++)
                sum[i] = new double[len];
            Iterator<DataObject> iter = objects.iterator();
            while (iter.hasNext()) {
                DataObject object = iter.next();
                int id = object.getCid();
                count[id]++;
                for (int i = 0; i < len; i++)
                    sum[id][i] += object.getVector()[i];
            }
            for (int i = 0; i < k; i++) {
                if (count[i] != 0) {
                    for (int j = 0; j < len; j++) {
                        sum[i][j] /= count[i];
                    }
                }
                // the cluster is empty: place its new centroid at the mean of other centroids
                else {
                    int a=(i+1)%k;
                    int b=(i+3)%k;
                    int c=(i+5)%k;
                    for (int j = 0; j < len; j++) {
                        sum[i][j] = (center[a][j]+center[b][j]+center[c][j])/3;
                    }
                }
            }
            for (int i = 0; i < k; i++) {
                // if any centroid has to move by mu or more, return false
                if (Global.calEuraDist(sum[i], center[i], len) >= mu) {
                    end = false;
                    break;
                }
            }
            if (!end) {
                for (int i = 0; i < k; i++) {
                    for (int j = 0; j < len; j++)
                        center[i][j] = sum[i][j];
                }
            }
            return end;
        }
      
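        // Score of this clustering: the sum over all clusters of (cluster size × within-cluster sum of squared distances to the centroid); lower is better. len is the dimensionality of the data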
        public double getSati(ArrayList<DataObject> objects, int len) {
            double satisfy = 0.0;
            int[] count = new int[k];
            double[] ss = new double[k];
            Iterator<DataObject> iter = objects.iterator();
            while (iter.hasNext()) {
                DataObject object = iter.next();
                int id = object.getCid();
                count[id]++;
                for (int i = 0; i < len; i++)
                    ss[id] += Math.pow(object.getVector()[i] - center[id][i], 2.0);
            }
            for (int i = 0; i < k; i++) {
                satisfy += count[i] * ss[i];
            }
            return satisfy;
        }
      
        public double run(int round, DataSource datasource, int len) {
            System.out.println("Run " + round);
            initCenter(len,datasource.objects);
            classify(datasource.objects);
            while (!calNewCenter(datasource.objects, len)) {
                classify(datasource.objects);
            }
            datasource.printResult(datasource.objects, k);
            double ss = getSati(datasource.objects, len);
            System.out.println("Weighted variance: " + ss);
            return ss;
        }
      
        public static void main(String[] args) {
            DataSource datasource = new DataSource();
            datasource.readMatrix(new File("/home/orisun/test/dot.mat"));
            datasource.readRLabel(new File("/home/orisun/test/dot.rlabel"));
            int len = datasource.col;
            // Partition into 4 clusters, stop once every centroid moves less than 1E-10, repeat the run 7 times
            KMeans km = new KMeans(4, 1E-10, 7, len);
            int index = 0;
            double minsa = Double.MAX_VALUE;
            for (int i = 0; i < km.repeat; i++) {
                double ss = km.run(i, datasource, len);
                if (ss < minsa) {
                    minsa = ss;
                    index = i;
                }
            }
            System.out.println("The best result is run " + index + ".");
        }
    }
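The post does not show how the matrix M printed by the C program ends up in the `dot.mat` file that `readMatrix` loads. Judging only from what `readMatrix` parses, the file starts with a line holding the row and column counts, followed by one whitespace-separated row per line. The sketch below writes a matrix in that inferred layout; the class name `MatrixFileSketch`, the method `format`, and the sample values are hypothetical.

```java
public class MatrixFileSketch {

    // Formats a matrix as text: a "<rows> <cols>" header line, then one
    // space-separated row per line (the layout inferred from readMatrix).
    static String format(double[][] m) {
        StringBuilder sb = new StringBuilder();
        int cols = m.length == 0 ? 0 : m[0].length;
        sb.append(m.length).append(' ').append(cols).append('\n');
        for (double[] row : m) {
            for (int j = 0; j < row.length; j++) {
                if (j > 0) sb.append(' ');
                sb.append(row[j]);
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Made-up 4×2 matrix standing in for the eigenvector matrix M
        double[][] m = {{0.7, 0.0}, {0.7, 0.0}, {0.0, 0.7}, {0.0, 0.7}};
        System.out.print(format(m));
    }
}
```

The label file read by `readRLabel` would then hold one name per line, in the same row order.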
    Original source: cnblogs (华夏35度) http://www.cnblogs.com/zhangchaoyang Author: Orisun
  • Original URL: https://www.cnblogs.com/lm3306/p/9314032.html