zoukankan      html  css  js  c++  java
  • 聚类结果的评估指标及其JAVA实现

    一. 前言

    又GET了一项技能。在做聚类算法的时候,由于要评估所提出的聚类算法的好坏,于是需要与一些已知的算法对比,或者用一些人工标注的标签来比较,于是用到了聚类结果的评估指标。我了解了以下几项。


    首先定义几个量(借鉴该博客:http://blog.csdn.net/luoleicn/article/details/5350378)。注意:以下各量都是针对"对象对"(数据集中任意两个对象的组合)来统计的。

    TP:是指被聚在一类的两个量被正确的分类了(即在标准标注里属于一类的两个对象被聚在一类)

    TN:是指不应该被聚在一类的两个对象被正确地分开了(即在标准标注里不是一类的两个对象在待测结果也没聚在一类)

    FP:指不应该放在一类的对象被错误的放在了一类。(即在标准标注里不是一类,但在待测结果里聚在一类)

    FN:指不应该分开的对象被错误的分开了。(即在标准标注里是一类,但在待测结果里没聚在一类)

    P = TP + FN(标准标注里属于同一类的对象对总数)

    N = TN + FP(标准标注里不属于同一类的对象对总数)

    1.准确率、识别率:(Rand Index)  RI

    accuracy = (TP + TN)/(P + N)


    2.错误率、误分类率

    error rate = (FP + FN)/(P + N)


    3.敏感度


    sensitivity = TP / P


    4.特异性


    specificity = TN / N


    5.精度


    precision = TP  /   (TP + FP)


    6.召回率


    recall  =  TP  /   (TP  + FN)


    7.RI  其实就是  1  的 accuracy


    8.F度量

    $F_\beta = \dfrac{(\beta^2 + 1)\, P \, R}{\beta^2 P + R}$

    其中 P 为 precision,R 为 recall,β 为均衡参数(β = 1 时即常用的 F1 值)


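    9.NMI(normalized mutual information)

    $\mathrm{NMI}(X, Y) = \dfrac{I(X; Y)}{\sqrt{H(X)\, H(Y)}}$

    其中 $I(X;Y) = \sum_{x,y} p(x,y) \log_2 \dfrac{p(x,y)}{p(x)\,p(y)}$ 为互信息,$H(X) = -\sum_x p(x) \log_2 p(x)$ 为熵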



    10 Jaccard

    J = TP  / (TP + FP + FN)



    二、JAVA实现(未优化)

    其中很多重复代码,还没有优化。。。


    package others;
    
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.Map;
    import java.util.Map.Entry;
    import java.util.Set;
    
    import javax.rmi.CORBA.Util;
    
    import org.graphstream.algorithm.measure.NormalizedMutualInformation;
    
    
    /**
     * Common external cluster-evaluation measures: purity, precision, recall,
     * Rand index (RI), F-score, Jaccard coefficient and normalized mutual
     * information (NMI).
     *
     * <p>Every metric takes two label arrays of equal length: {@code A} is the
     * clustering being evaluated and {@code B} is the reference labelling, with
     * {@code A[i]} and {@code B[i]} being the labels of object {@code i}.
     * The pair-counting metrics classify each unordered pair of objects as
     * TP (same cluster in A, same class in B), FP (same in A, different in B),
     * FN (different in A, same in B) or TN (different in both).
     *
     * <p>Ratios whose denominator is zero (for example any metric on empty
     * input) evaluate to {@code NaN}.
     *
     * @author Wenbao Li
     * Date: 2015-07-13
     */
    public class ClusterEvaluation {

    	/** Positions of the four counts inside the array built by pairCounts. */
    	private static final int TP_IDX = 0;
    	private static final int FP_IDX = 1;
    	private static final int FN_IDX = 2;
    	private static final int TN_IDX = 3;

    	/** Small demo: prints several metrics for two hard-coded labellings. */
    	public static void main(String[] args){
    		int[] A = {1,3,3,3,3,3,3,2,1,0,2,0,2,0,2,1,1,0,1,1};
    		int[] B = {2,2,0,0,0,3,2,2,3,1,3,1,0,1,2,1,0,1,3,3};
    		double purity = Purity(A,B);
    		System.out.println("purity		"+purity);
    		System.out.println("Pre		"+Precision(A,B));
    		System.out.println("Recall		"+Recall(A,B));
    		System.out.println("RI(Accuracy)		"+RI(A,B));
    		System.out.println("Fvalue		"+F_score(A,B));
    		System.out.println("NMI		"+NMI(A,B));

    	}

    	/**
    	 * Groups object indices by cluster label.
    	 *
    	 * @param A label of each object
    	 * @return map from label to the set of 1-based positions in {@code A}
    	 *         carrying that label
    	 */
    	public static Map<Integer,Set<Integer>> clusterDistri(int[] A){
    		Map<Integer,Set<Integer>> clusterD = new HashMap<Integer,Set<Integer>>();
    		for(int i = 0;i < A.length;i++){
    			Set<Integer> members = clusterD.get(A[i]);
    			if(members == null){
    				members = new HashSet<Integer>();
    				clusterD.put(A[i], members);
    			}
    			// Indices are stored 1-based, as callers of this method expect.
    			members.add(i+1);
    		}
    		return clusterD;
    	}

    	/**
    	 * Dispatches to a metric by name.
    	 *
    	 * @param method one of "Purity", "Precision", "Recall", "RI", "F_score",
    	 *               "NMI", "Jaccard"
    	 * @param A clustering under evaluation
    	 * @param B reference labelling
    	 * @return the metric value, or -1.0 if {@code method} is null or unknown
    	 */
    	public static double ClusEvaluate(String method,int[] A,int[] B){
    		if(method == null){
    			return -1.0;
    		}
    		switch(method){
    		case "Purity":
    			return Purity(A,B);
    		case "Precision":
    			return Precision(A,B);
    		case "Recall":
    			return Recall(A,B);
    		case "RI":
    			return RI(A,B);
    		case "F_score":
    			return F_score(A,B);
    		case "NMI":
    			return NMI(A,B);
    		case "Jaccard":
    			return Jaccard(A,B);
    		default:
    			return -1.0;
    		}
    	}

    	/**
    	 * For every cluster of {@code A}, finds the size of its largest overlap
    	 * with any cluster of {@code B}.
    	 *
    	 * @param A clusters of the evaluated result (label to member indices)
    	 * @param B clusters of the reference (label to member indices)
    	 * @return one entry per cluster of {@code A}, in the iteration order of
    	 *         {@code A}; an entry is -1 when {@code B} has no clusters
    	 */
    	public static int[] commNum(Map<Integer,Set<Integer>> A,Map<Integer,Set<Integer>> B){
    		int[] commonNo = new int[A.size()];
    		int i = 0;
    		for(Set<Integer> setA : A.values()){
    			int maxComm = -1;
    			for(Set<Integer> setB : B.values()){
    				maxComm = Math.max(maxComm, intersect(setA, setB));
    			}
    			commonNo[i++] = maxComm;
    		}
    		return commonNo;
    	}

    	/**
    	 * Purity: each cluster of {@code A} is credited with its largest overlap
    	 * with a class of {@code B}; the credits are summed and divided by the
    	 * number of objects.
    	 *
    	 * @param A clustering under evaluation
    	 * @param B reference labelling
    	 * @return purity in [0,1], or NaN for empty input
    	 * @throws IllegalArgumentException if the arrays differ in length
    	 */
    	public static double Purity(int[] A,int[] B){
    		requireSameLength(A,B);
    		int[] commonNo = commNum(clusterDistri(A),clusterDistri(B));
    		long com = 0;
    		for(int c : commonNo){
    			com += c;
    		}
    		return com*1.0/A.length;
    	}

    	/**
    	 * Pairwise precision, TP / (TP + FP).
    	 *
    	 * @param A clustering under evaluation
    	 * @param B reference labelling
    	 * @return precision, or NaN if no two objects share a cluster in {@code A}
    	 * @throws IllegalArgumentException if the arrays differ in length
    	 */
    	public static double Precision(int[] A,int[] B){
    		long[] c = pairCounts(A,B);
    		return c[TP_IDX]*1.0/(c[TP_IDX] + c[FP_IDX]);
    	}

    	/**
    	 * Pairwise recall, TP / (TP + FN).
    	 *
    	 * @param A clustering under evaluation
    	 * @param B reference labelling
    	 * @return recall, or NaN if no two objects share a class in {@code B}
    	 * @throws IllegalArgumentException if the arrays differ in length
    	 */
    	public static double Recall(int[] A,int[] B){
    		long[] c = pairCounts(A,B);
    		return c[TP_IDX]*1.0/(c[TP_IDX] + c[FN_IDX]);
    	}

    	/**
    	 * Rand index (pairwise accuracy), (TP + TN) / (TP + FP + FN + TN).
    	 *
    	 * @param A clustering under evaluation
    	 * @param B reference labelling
    	 * @return Rand index, or NaN when there are fewer than two objects
    	 * @throws IllegalArgumentException if the arrays differ in length
    	 */
    	public static double RI(int[] A,int[] B){
    		long[] c = pairCounts(A,B);
    		long all = c[TP_IDX] + c[FP_IDX] + c[FN_IDX] + c[TN_IDX];
    		return (c[TP_IDX] + c[TN_IDX])*1.0/all;
    	}

    	/**
    	 * F-score with beta = 1 (harmonic mean of pairwise precision and recall).
    	 *
    	 * @param A clustering under evaluation
    	 * @param B reference labelling
    	 * @return F1 value; NaN when precision or recall is undefined or both are 0
    	 * @throws IllegalArgumentException if the arrays differ in length
    	 */
    	public static double F_score(int[] A,int[] B){
    		return F_score(A,B,1.0);
    	}

    	/**
    	 * F-score, (beta^2 + 1) * P * R / (beta^2 * P + R), where P and R are
    	 * the pairwise precision and recall.
    	 *
    	 * @param A clustering under evaluation
    	 * @param B reference labelling
    	 * @param beta balance parameter; values above 1 weight recall more
    	 * @return F value; NaN when precision or recall is undefined or both are 0
    	 * @throws IllegalArgumentException if the arrays differ in length
    	 */
    	public static double F_score(int[] A,int[] B,double beta){
    		long[] c = pairCounts(A,B);
    		double P = c[TP_IDX]*1.0/(c[TP_IDX] + c[FP_IDX]);
    		double R = c[TP_IDX]*1.0/(c[TP_IDX] + c[FN_IDX]);
    		return (beta*beta + 1)*P*R/(beta*beta*P + R);
    	}

    	/**
    	 * Pairwise Jaccard coefficient, TP / (TP + FP + FN).
    	 *
    	 * @param A clustering under evaluation
    	 * @param B reference labelling
    	 * @return Jaccard coefficient, or NaN when TP + FP + FN is 0
    	 * @throws IllegalArgumentException if the arrays differ in length
    	 */
    	public static double Jaccard(int[] A,int[] B){
    		long[] c = pairCounts(A,B);
    		return c[TP_IDX]*1.0/(c[TP_IDX] + c[FP_IDX] + c[FN_IDX]);
    	}

    	/**
    	 * Normalized mutual information between the two labellings; see
    	 * {@link #computeNMI(Set, Set, int)} for the formula.
    	 *
    	 * @param A clustering under evaluation
    	 * @param B reference labelling
    	 * @return NMI value
    	 * @throws IllegalArgumentException if the arrays differ in length
    	 */
    	public static double NMI(int[] A,int[] B){
    		requireSameLength(A,B);
    		Set<Set<Integer>> partitionF = new HashSet<Set<Integer>>(clusterDistri(A).values());
    		Set<Set<Integer>> partitionR = new HashSet<Set<Integer>>(clusterDistri(B).values());
    		return computeNMI(partitionF,partitionR,B.length);
    	}

    	/**
    	 * Computes I(X;Y) / sqrt(H(X) * H(Y)) with base-2 logarithms, where X is
    	 * the cluster membership under {@code partitionR} and Y the membership
    	 * under {@code partitionF}.
    	 *
    	 * @param partitionF clusters of the evaluated result
    	 * @param partitionR clusters of the reference
    	 * @param nodeCount total number of objects
    	 * @return the NMI; NaN if either partition has no clusters; when an
    	 *         entropy is zero (a partition with a single cluster) the result
    	 *         is 1.0 if both entropies are zero and 0.0 otherwise
    	 */
    	public static double computeNMI(Set<Set<Integer>> partitionF,
    			Set<Set<Integer>> partitionR,int nodeCount) {
    		int rows = partitionR.size();
    		int cols = partitionF.size();
    		if(rows == 0 || cols == 0){
    			return Double.NaN;
    		}
    		// XY[i][j]: objects shared by reference cluster i and evaluated cluster j;
    		// X and Y are the row and column sums of that contingency table.
    		int[][] XY = new int[rows][cols];
    		int[] X = new int[rows];
    		int[] Y = new int[cols];
    		int i = 0;
    		for (Set<Integer> com1 : partitionR) {
    			int j = 0;
    			for (Set<Integer> com2 : partitionF) {
    				XY[i][j] = intersect(com1, com2);
    				X[i] += XY[i][j];
    				Y[j] += XY[i][j];
    				j++;
    			}
    			i++;
    		}
    		double N = nodeCount;
    		double Ixy = 0;
    		for (i = 0; i < rows; i++) {
    			for (int j = 0; j < cols; j++) {
    				if (XY[i][j] > 0) {
    					// Multiply the marginals as doubles: an int product can overflow.
    					double ratio = XY[i][j] * N / ((double) X[i] * Y[j]);
    					Ixy += (XY[i][j] / N) * (Math.log(ratio) / Math.log(2.0));
    				}
    			}
    		}

    		double Hx = 0;
    		double Hy = 0;
    		for (i = 0; i < rows; i++) {
    			if (X[i] > 0)
    				Hx += h(X[i] / N);
    		}
    		for (int j = 0; j < cols; j++) {
    			if (Y[j] > 0)
    				Hy += h(Y[j] / N);
    		}

    		if (Hx == 0 || Hy == 0) {
    			// Avoid 0/0: two single-cluster partitions agree perfectly, while a
    			// single-cluster partition shares no information with any other.
    			return (Hx == 0 && Hy == 0) ? 1.0 : 0.0;
    		}
    		return Ixy / Math.sqrt(Hx * Hy);
    	}

    	/** Entropy term -p * log2(p). */
    	private static double h(double p) {
    		return -p * (Math.log(p) / Math.log(2.0));
    	}

    	/** Number of elements the two sets have in common. */
    	private static int intersect(Set<Integer> com1, Set<Integer> com2) {
    		// Walk the smaller set and probe the larger one.
    		Set<Integer> small = com1.size() <= com2.size() ? com1 : com2;
    		Set<Integer> large = (small == com1) ? com2 : com1;
    		int num = 0;
    		for (Integer v : small) {
    			if (large.contains(v))
    				num++;
    		}
    		return num;
    	}

    	/** Rejects label arrays that do not describe the same set of objects. */
    	private static void requireSameLength(int[] A,int[] B){
    		if(A.length != B.length){
    			throw new IllegalArgumentException(
    					"label arrays must have the same length: "+A.length+" vs "+B.length);
    		}
    	}

    	/** Number of unordered pairs among k items; 0 when k is below 2. */
    	private static long pairs(long k){
    		return k*(k-1)/2;
    	}

    	/**
    	 * Counts the unordered object pairs in each of the four categories and
    	 * returns them as {TP, FP, FN, TN}.
    	 */
    	private static long[] pairCounts(int[] A,int[] B){
    		requireSameLength(A,B);
    		long tp = 0;
    		long sameInA = 0;
    		for(Set<Integer> members : clusterDistri(A).values()){
    			sameInA += pairs(members.size());
    			// Tally the reference classes inside this cluster; every class
    			// with c members contributes C(c,2) true-positive pairs.
    			Map<Integer,Integer> classSizes = new HashMap<Integer,Integer>();
    			for(Integer idx : members){
    				Integer label = Integer.valueOf(B[idx-1]);
    				Integer seen = classSizes.get(label);
    				classSizes.put(label, seen == null ? 1 : seen + 1);
    			}
    			for(Integer c : classSizes.values()){
    				tp += pairs(c);
    			}
    		}
    		long sameInB = 0;
    		for(Set<Integer> members : clusterDistri(B).values()){
    			sameInB += pairs(members.size());
    		}
    		long fp = sameInA - tp;
    		long fn = sameInB - tp;
    		long tn = pairs(A.length) - tp - fp - fn;
    		long[] counts = new long[4];
    		counts[TP_IDX] = tp;
    		counts[FP_IDX] = fp;
    		counts[FN_IDX] = fn;
    		counts[TN_IDX] = tn;
    		return counts;
    	}

    	/**
    	 * Binomial coefficient C(m,n), "m choose n".
    	 *
    	 * @param m size of the pool
    	 * @param n number of items chosen
    	 * @return C(m,n), or -1 if {@code n < 0} or {@code m < n}
    	 * @throws ArithmeticException if the result does not fit in an int
    	 */
    	public static int combination(int m,int n){
    		if(n < 0 || m < n){
    			return -1;
    		}
    		int k = Math.min(n, m-n);
    		// Multiplicative form: after step i the value is C(m-k+i, i), so the
    		// division is exact and no factorial is ever materialised.
    		long result = 1;
    		for(int i = 1;i <= k;i++){
    			result = result*(m-k+i)/i;
    			if(result > Integer.MAX_VALUE){
    				throw new ArithmeticException("C("+m+","+n+") overflows int");
    			}
    		}
    		return (int) result;
    	}

    	/**
    	 * Factorial of {@code m}.
    	 *
    	 * @param m non-negative integer
    	 * @return m!, or -1 if {@code m} is negative
    	 * @throws ArithmeticException if {@code m > 12}, since 13! exceeds int range
    	 */
    	public static int factorial(int m){
    		if(m < 0){
    			return -1;
    		}
    		if(m > 12){
    			throw new ArithmeticException(m+"! overflows int");
    		}
    		int result = 1;
    		for(int i = 2;i <= m;i++){
    			result *= i;
    		}
    		return result;
    	}
    }
    




  • 相关阅读:
    MVC新手指南
    BufferedReader方法-----Scanner方法
    sin=in.readLine();
    STL:string 大小(Size)和容量(Capacity)
    2014=9=24 连接数据库2
    2014=9=24 连接数据库1
    常用英语单词
    Linux权限详解(chmod、600、644、666、700、711、755、777、4755、6755、7755)
    linux 常用快捷键
    启动sh文件注意的问题
  • 原文地址:https://www.cnblogs.com/wenbaoli/p/5655742.html
Copyright © 2011-2022 走看看