zoukankan      html  css  js  c++  java
  • 用 Java 实现聚类评价指标的计算:Purity、NMI、RI、Precision、Recall、F 值。

    第一个类:计算 NMI(归一化互信息,Normalized Mutual Information):

    package clusters;

    import java.io.*;
    import java.nio.charset.StandardCharsets;
    import java.nio.file.Files;
    import java.nio.file.Paths;
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.regex.Pattern;

    /**
    * DATE: 16-6-18 TIME: 上午10:00
    */

    /**
    * 参考文献:http://www-nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html
    */
    public class NormalizedMutualInformation {
        /**
         * Input file: one cluster per line, each line holding the whitespace-separated integer
         * class labels of that cluster's members. Mutable so callers can point the loader elsewhere.
         */
        public static String path = "/home/fhqplzj/IdeaProjects/Vein/src/main/resources/nmi_data";

        /** Token separator: any run of whitespace (spaces, tabs, ...). */
        private static final Pattern WHITESPACE = Pattern.compile("\\s+");

        /** Natural log of 2, used to convert {@link Math#log} results to bits. */
        private static final double LN_2 = Math.log(2);

        /**
         * Reads {@link #path} and appends one list of class labels per non-blank line to {@code lists}.
         *
         * @param lists destination; each element added is one cluster
         * @throws UncheckedIOException if the file cannot be opened or read (instead of silently
         *     leaving {@code lists} empty or partially filled)
         * @throws NumberFormatException if a token is not an integer
         */
        public static void loadData(List<List<Integer>> lists) {
            // try-with-resources closes the reader even when parsing throws.
            try (BufferedReader reader = Files.newBufferedReader(Paths.get(path), StandardCharsets.UTF_8)) {
                String line;
                while ((line = reader.readLine()) != null) {
                    String trimmed = line.trim();
                    if (trimmed.isEmpty()) {
                        continue; // blank lines (e.g. a trailing newline) are not clusters
                    }
                    List<Integer> labels = new ArrayList<>();
                    // Trimming first avoids an empty leading token when the line starts with whitespace.
                    for (String token : WHITESPACE.split(trimmed)) {
                        labels.add(Integer.parseInt(token));
                    }
                    lists.add(labels);
                }
            } catch (IOException e) {
                throw new UncheckedIOException("Failed to read cluster data from " + path, e);
            }
        }

        /**
         * Loads the clusters from {@link #path} and prints their normalized mutual information
         * with two decimals, as {@code "nmi = 0.36"}.
         *
         * @param args ignored (callers may pass {@code null})
         */
        public static void main(String[] args) {
            List<List<Integer>> lists = new ArrayList<>();
            loadData(lists);
            System.out.println(String.format("nmi = %.2f", computeNmi(lists)));
        }

        /**
         * Computes NMI(Ω, C) = 2·I(Ω; C) / (H(Ω) + H(C)), where Ω is the clustering (one list per
         * cluster) and C is the partition induced by the class labels. Returns 1.0 when both
         * entropies are zero (a single cluster that is also a single class, or no data): the two
         * partitions are then identical, and this avoids the 0/0 = NaN case.
         */
        private static double computeNmi(List<List<Integer>> clusters) {
            int n = 0;
            Map<Integer, Integer> classSizes = new HashMap<>();
            for (List<Integer> cluster : clusters) {
                n += cluster.size();
                countInto(cluster, classSizes);
            }

            double clusterEntropy = 0;
            for (List<Integer> cluster : clusters) {
                clusterEntropy += entropyTerm(cluster.size(), n);
            }
            double classEntropy = 0;
            for (int size : classSizes.values()) {
                classEntropy += entropyTerm(size, n);
            }

            double mutualInformation = 0;
            Map<Integer, Integer> overlap = new HashMap<>();
            for (List<Integer> cluster : clusters) {
                overlap.clear();
                countInto(cluster, overlap);
                double clusterSize = cluster.size();
                for (Map.Entry<Integer, Integer> entry : overlap.entrySet()) {
                    double classSize = classSizes.get(entry.getKey());
                    double both = entry.getValue();
                    // Keep everything in double: an int product clusterSize * classSize overflows
                    // once it exceeds Integer.MAX_VALUE (about 46k points).
                    mutualInformation += both / n * log2(n * both / (clusterSize * classSize));
                }
            }

            double denominator = clusterEntropy + classEntropy;
            if (denominator == 0) {
                return 1.0;
            }
            return 2 * mutualInformation / denominator;
        }

        /** Adds one occurrence of every label in {@code labels} to {@code counts}. */
        private static void countInto(List<Integer> labels, Map<Integer, Integer> counts) {
            for (Integer label : labels) {
                counts.merge(label, 1, Integer::sum);
            }
        }

        /** Returns -p·log2(p) for p = count / n; an empty block contributes 0 rather than NaN. */
        private static double entropyTerm(int count, int n) {
            if (count == 0) {
                return 0;
            }
            double p = (double) count / n;
            return -p * log2(p);
        }

        /** Base-2 logarithm. */
        private static double log2(double x) {
            return Math.log(x) / LN_2;
        }
    }

    //////////////////////////////////////////////

    第二个类:一些工具方法(ClusterUtils):

    package clusters;

    import java.util.Arrays;
    import java.util.List;
    import java.util.Map;

    /**
    * DATE: 16-6-18 TIME: 上午11:07
    */
    public class ClusterUtils {
        /**
         * Returns the binomial coefficient C(n, k), the number of k-element subsets of an n-element set.
         *
         * <p>Uses the multiplicative formula in {@code long} arithmetic, which takes O(min(k, n - k))
         * time, instead of building a Pascal row in O(n^2) time and O(n) memory.
         *
         * @return C(n, k), or 0 when {@code k < 0} or {@code k > n}
         * @throws ArithmeticException if the result does not fit in an {@code int}
         */
        public static int combination(int n, int k) {
            if (k < 0 || k > n) {
                return 0;
            }
            int m = Math.min(k, n - k);
            long result = 1;
            for (int i = 1; i <= m; i++) {
                // Invariant: result == C(n - m + i - 1, i - 1) here, so this division is exact.
                result = result * (n - m + i) / i;
                // The partial values are non-decreasing in i, so exceeding int now means the final value would too.
                if (result > Integer.MAX_VALUE) {
                    throw new ArithmeticException("C(" + n + ", " + k + ") overflows int");
                }
            }
            return (int) result;
        }

        /**
         * Returns the number of unordered pairs placed in the same cluster (TP + FP).
         *
         * @param clusters the size of each cluster
         */
        public static int computeTPAndFP(int[] clusters) {
            int result = 0;
            for (int cluster : clusters) {
                result += combination(cluster, 2);
            }
            return result;
        }

        /**
         * Returns the number of false-positive pairs: pairs in the same cluster whose members have
         * different classes. For each cluster this is C(size, 2) minus its same-class pairs.
         * Note that the sum of C(count, 2) over a cluster's class counts alone is the TP count, not FP.
         *
         * @param mapList one map per cluster, from class label to the number of members with that label
         */
        public static int computeFP(List<Map<Integer, Integer>> mapList) {
            int fp = 0;
            for (Map<Integer, Integer> classCounts : mapList) {
                int clusterSize = 0;
                int sameClassPairs = 0;
                for (int count : classCounts.values()) {
                    clusterSize += count;
                    sameClassPairs += combination(count, 2);
                }
                fp += combination(clusterSize, 2) - sameClassPairs;
            }
            return fp;
        }

        /**
         * Returns the sum of x_i * x_j over all index pairs i < j. Given how many members of one
         * class fall into each cluster, this is that class's number of false-negative pairs.
         *
         * @param list per-cluster member counts for a single class; an empty list yields 0
         */
        public static int computeOneClass(List<Integer> list) {
            int result = 0;
            int seen = 0; // sum of the counts before the current one
            for (int count : list) {
                result += count * seen;
                seen += count;
            }
            return result;
        }

        /**
         * Returns the number of false-negative pairs: same class, different clusters.
         *
         * @param lists for each class, its member count in each cluster that contains it
         */
        public static int computeFN(List<List<Integer>> lists) {
            int result = 0;
            for (List<Integer> list : lists) {
                result += computeOneClass(list);
            }
            return result;
        }

        /**
         * Returns F_beta = (beta^2 + 1) * P * R / (beta^2 * P + R).
         *
         * @param P precision
         * @param R recall
         * @param beta weight of recall relative to precision (beta > 1 favours recall)
         * @return F_beta, or 0 when the denominator is 0 (for example P = R = 0) instead of NaN
         */
        public static double computeFValue(double P, double R, double beta) {
            double betaSquared = beta * beta;
            double denominator = betaSquared * P + R;
            if (denominator == 0) {
                return 0.0;
            }
            return (betaSquared + 1) * P * R / denominator;
        }

        /** Small manual check of {@link #computeOneClass}. */
        public static void main(String[] args) {
            List<Integer> list = Arrays.asList(1, 4, 0);
            System.out.println("computeOneClass(list) = " + computeOneClass(list));
        }
    }

    第三个类:计算 RI、P、R、F 以及 Purity,并调用 NMI 的计算,将结果一起打印输出。与 Stanford 文章一致,beta 分别取 1 和 5,得到 F1 和 F5:

    package clusters;

    import java.util.*;

    /**
    * DATE: 16-6-18 TIME: 上午11:05
    */
    public class RandIndex {
        /**
         * Loads the clustering via {@link NormalizedMutualInformation#loadData} and prints purity,
         * NMI, the Rand index, pair-counting precision and recall, and F_beta for beta = 1 and 5.
         *
         * <p>Pair counting: TP = same cluster and same class, FP = same cluster and different class,
         * FN = different clusters and same class, TN = different clusters and different class.
         *
         * @param args ignored
         */
        public static void main(String[] args) {
            List<List<Integer>> lists = new ArrayList<>();
            NormalizedMutualInformation.loadData(lists);

            int n = 0;
            int[] clusterSizes = new int[lists.size()];
            for (int i = 0; i < clusterSizes.length; i++) {
                clusterSizes[i] = lists.get(i).size();
                n += clusterSizes[i];
            }
            List<Map<Integer, Integer>> classCountsPerCluster = countClassesPerCluster(lists);

            int tpAndFp = ClusterUtils.computeTPAndFP(clusterSizes);
            // Same-cluster pairs that also share a class are TP; the remaining same-cluster pairs are FP.
            int tp = countSameClassPairs(classCountsPerCluster);
            int fp = tpAndFp - tp;
            int fn = ClusterUtils.computeFN(clusterCountsPerClass(classCountsPerCluster));
            int tn = ClusterUtils.combination(n, 2) - tpAndFp - fn;
            double ri = 1.0 * (tp + tn) / (tp + fp + fn + tn);

            double purity = 1.0 * sumOfMajorityCounts(classCountsPerCluster) / n;
            System.out.println(String.format("purity = %.2f", purity));
            // Prints the NMI line; this reloads the same data file.
            NormalizedMutualInformation.main(null);
            System.out.println(String.format("RI = %.2f", ri));

            double precision = 1.0 * tp / (tp + fp);
            double recall = 1.0 * tp / (tp + fn);
            System.out.println(String.format("P = %.2f", precision));
            System.out.println(String.format("R = %.3f", recall));
            System.out.println(String.format("beta = 1, F = %.2f", ClusterUtils.computeFValue(precision, recall, 1)));
            System.out.println(String.format("beta = 5, F = %.3f", ClusterUtils.computeFValue(precision, recall, 5)));
        }

        /** For each cluster, maps each class label to the number of members with that label. */
        private static List<Map<Integer, Integer>> countClassesPerCluster(List<List<Integer>> clusters) {
            List<Map<Integer, Integer>> result = new ArrayList<>(clusters.size());
            for (List<Integer> cluster : clusters) {
                Map<Integer, Integer> counts = new HashMap<>();
                for (Integer label : cluster) {
                    counts.merge(label, 1, Integer::sum);
                }
                result.add(counts);
            }
            return result;
        }

        /** Returns the number of unordered pairs in the same cluster with the same class (TP). */
        private static int countSameClassPairs(List<Map<Integer, Integer>> classCountsPerCluster) {
            int pairs = 0;
            for (Map<Integer, Integer> counts : classCountsPerCluster) {
                for (int count : counts.values()) {
                    pairs += ClusterUtils.combination(count, 2);
                }
            }
            return pairs;
        }

        /** For each class, lists its member count in every cluster that contains it. */
        private static List<List<Integer>> clusterCountsPerClass(List<Map<Integer, Integer>> classCountsPerCluster) {
            Map<Integer, List<Integer>> byClass = new LinkedHashMap<>();
            for (Map<Integer, Integer> counts : classCountsPerCluster) {
                for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
                    byClass.computeIfAbsent(entry.getKey(), label -> new ArrayList<>()).add(entry.getValue());
                }
            }
            return new ArrayList<>(byClass.values());
        }

        /** Sums, over clusters, the size of each cluster's largest class; empty clusters add 0. */
        private static int sumOfMajorityCounts(List<Map<Integer, Integer>> classCountsPerCluster) {
            int total = 0;
            for (Map<Integer, Integer> counts : classCountsPerCluster) {
                int max = 0;
                for (int count : counts.values()) {
                    max = Math.max(max, count);
                }
                total += max;
            }
            return total;
        }
    }

    输入数据即 Stanford 文中示例的 3 个簇:每行是一个簇,行内以空白分隔的每个整数是该簇中一个元素的真实类别标签:

    1 1 1 1 1 2
    1 2 2 2 2 3
    1 1 3 3 3

    本文来自http://blog.csdn.net/asd991936157/article/details/51705958,只为学习

  • 相关阅读:
    :nth-child :nth-type-of用法详解
    hosts修改备份
    微信小程序 报警告的解决办法
    微信小程序 body属性的问题
    关于微信小程序post请求数据的坑
    在做展开功能的时候,字体变多了,字体会变大的bug的解决方案
    关于微信小程序并发数不能超过五个的问题
    单行文本省略号与多行文本省略号的实现
    js数据类型判断
    表格td标签在不添加多余标签的情况下实现文本内容单行显示,多余部分省略号表示的方法
  • 原文地址:https://www.cnblogs.com/altlb/p/8021904.html
Copyright © 2011-2022 走看看