zoukankan      html  css  js  c++  java
  • [Java]数据分析--聚类

    距离度量

    • 需求:计算两点间的欧几里得距离、曼哈顿距离、切比雪夫距离、堪培拉距离
    • 实现:利用commons.math3库相应函数
     1 import org.apache.commons.math3.ml.distance.*;
     2 
     3 public class TestMetrics {
     4     public static void main(String[] args) {
     5         double[] x = {1, 3}, y = {5, 6};
     6         
     7         EuclideanDistance eD = new EuclideanDistance();
     8         System.out.printf("Euclidean distance = %.2f%n", eD.compute(x,y));
     9         
    10         ManhattanDistance mD = new ManhattanDistance();
    11         System.out.printf("Manhattan distance = %.2f%n", mD.compute(x,y));
    12         
    13         ChebyshevDistance cD = new ChebyshevDistance();
    14         System.out.printf("Chebyshev distance = %.2f%n", cD.compute(x,y));
    15         
    16         CanberraDistance caD = new CanberraDistance();
    17         System.out.printf("Canberra distance =  %.2f%n", caD.compute(x,y));
    18     }
    19 }
    View Code

    Euclidean distance = 5.00
    Manhattan distance = 7.00
    Chebyshev distance = 4.00
    Canberra distance =  1.00

    层次聚类

    • 需求:将13个样本点分为3类
    • 实现:m点划分为k类,先令m点的每个点为一类,然后找到中心最近的两个类,用一个新的聚类替换,重复m-k次

    HierarchicalClustering.java

     1 import java.util.HashSet;
     2 
     3 public class HierarchicalClustering {
     4     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
     5         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
     6     private static final int M = DATA.length;  // number of points
     7     private static final int K = 3;            // number of clusters
     8 
     9     public static void main(String[] args) {
    10         HashSet<Cluster> clusters = load(DATA);
    11         for (int i = 0; i < M - K; i++) {
    12             System.out.printf("%n%2d clusters:%n", M-i-1);
    13             coalesce(clusters);
    14             System.out.println(clusters);
    15         }
    16     }
    17     
    18     private static HashSet<Cluster> load(double[][] data) {
    19         HashSet<Cluster> clusters = new HashSet();
    20         for (double[] datum : DATA) {
    21             clusters.add(new Cluster(datum[0], datum[1]));
    22         }
    23         return clusters;
    24     } 
    25     
    26     private static void coalesce(HashSet<Cluster> clusters) {
    27         Cluster cluster1=null, cluster2=null;
    28         double minDist = Double.POSITIVE_INFINITY;
    29         for (Cluster c1 : clusters) {
    30             for (Cluster c2 : clusters) {
    31                 if (!c1.equals(c2) && Cluster.distance(c1, c2) < minDist) {
    32                     cluster1 = c1;
    33                     cluster2 = c2;
    34                     minDist = Cluster.distance(c1, c2);
    35                 }
    36             }
    37         }
    38         clusters.remove(cluster1);
    39         clusters.remove(cluster2);
    40         clusters.add(Cluster.union(cluster1, cluster2));
    41     }
    42 }
    View Code

    Point.java

     /**
      * An immutable point in the plane with value-based equality.
      *
      * <p>Two points are equal when both coordinates have identical
      * {@code Double.doubleToLongBits} representations, so NaN equals NaN and
      * 0.0 differs from -0.0, which keeps {@code equals} consistent with
      * {@code hashCode}.
      */
     public class Point {
         private final double x, y;

         public Point(double x, double y) {
             this.x = x;
             this.y = y;
         }

         /** @return the x coordinate */
         public double getX() {
             return x;
         }

         /** @return the y coordinate */
         public double getY() {
             return y;
         }

         @Override
         public int hashCode() {
             // Double.valueOf replaces the deprecated "new Double(...)" constructor;
             // the resulting hash value is unchanged.
             int xhC = Double.valueOf(x).hashCode();
             int yhC = Double.valueOf(y).hashCode();
             return xhC + 79*yhC;
         }

         @Override
         public boolean equals(Object object) {
             if (object == this) {
                 return true;
             }
             if (!(object instanceof Point)) {
                 return false;  // instanceof is false for null as well
             }
             Point that = (Point) object;
             return bits(that.x) == bits(this.x) && bits(that.y) == bits(this.y);
         }

         /** Returns the raw long-bits form of {@code d} used for exact comparison. */
         private static long bits(double d) {
             return Double.doubleToLongBits(d);
         }

         @Override
         public String toString() {
             return String.format("(%.2f,%.2f)", x,y);
         }
     }
    View Code

    Cluster.java

     1 import java.util.HashSet;
     2 
     3 public class Cluster {
     4     private final HashSet<Point> points;
     5     private Point centroid;
     6 
     7     public Cluster(HashSet points, Point centroid) {
     8         this.points = points;
     9         this.centroid = centroid;
    10     }
    11     
    12     public Cluster(Point point) {
    13         this.points = new HashSet();
    14         this.points.add(point);
    15         this.centroid = point;
    16     }
    17 
    18     public Cluster(double x, double y) {
    19         this(new Point(x,y));
    20     }
    21 
    22     public Point getCentroid() {
    23         return centroid;
    24     }
    25 
    26     public void add(Point point) {
    27         points.add(point);
    28         recomputeCentroid();
    29     }
    30 
    31     public void recomputeCentroid() {
    32         double xSum=0.0, ySum=0.0;
    33         for (Point point : points) {
    34             xSum += point.getX();
    35             ySum += point.getY();
    36         }
    37         centroid = new Point(xSum/points.size(), ySum/points.size());
    38     }
    39     
    40     public static double distance(Cluster c1, Cluster c2) {
    41         double dx = c1.centroid.getX() - c2.centroid.getX();
    42         double dy = c1.centroid.getY() - c2.centroid.getY();
    43         return Math.sqrt(dx*dx + dy*dy);
    44     }
    45     
    46     public static Cluster union(Cluster c1, Cluster c2) {
    47         Cluster cluster = new Cluster(c1.points, c1.centroid);
    48         cluster.points.addAll(c2.points);
    49         cluster.recomputeCentroid();
    50         return cluster;
    51     }
    52 
    53     @Override
    54     public int hashCode() {
    55         return points.hashCode();
    56     }
    57 
    58     @Override
    59     public boolean equals(Object object) {
    60         if (object == null) {
    61             return false;
    62         } else if (object == this) {
    63             return true;
    64         } else if (!(object instanceof Cluster)) {
    65             return false;
    66         }
    67         final Cluster that = (Cluster)object;
    68         return that.points.equals(this.points);
    69     }
    70 
    71     @Override
    72     public String toString() {
    73         return String.format("%n{%s,%s}", centroid, points);
    74     }
    75 }
    View Code

    结果-->

      1 12 clusters:
      2 [
      3 {(1.00,1.00),[(1.00,1.00)]}, 
      4 {(1.00,3.00),[(1.00,3.00)]}, 
      5 {(2.00,6.00),[(2.00,6.00)]}, 
      6 {(3.00,2.00),[(3.00,2.00)]}, 
      7 {(4.00,3.00),[(4.00,3.00)]}, 
      8 {(6.00,4.00),[(6.00,4.00)]}, 
      9 {(7.00,1.00),[(7.00,1.00)]}, 
     10 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
     11 {(6.00,3.00),[(6.00,3.00)]}, 
     12 {(3.00,4.00),[(3.00,4.00)]}, 
     13 {(1.00,5.00),[(1.00,5.00)]}, 
     14 {(5.00,6.00),[(5.00,6.00)]}]
     15 
     16 11 clusters:
     17 [
     18 {(1.00,1.00),[(1.00,1.00)]}, 
     19 {(1.00,3.00),[(1.00,3.00)]}, 
     20 {(2.00,6.00),[(2.00,6.00)]}, 
     21 {(3.00,2.00),[(3.00,2.00)]}, 
     22 {(4.00,3.00),[(4.00,3.00)]}, 
     23 {(7.00,1.00),[(7.00,1.00)]}, 
     24 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
     25 {(3.00,4.00),[(3.00,4.00)]}, 
     26 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
     27 {(1.00,5.00),[(1.00,5.00)]}, 
     28 {(5.00,6.00),[(5.00,6.00)]}]
     29 
     30 10 clusters:
     31 [
     32 {(1.00,1.00),[(1.00,1.00)]}, 
     33 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
     34 {(1.00,3.00),[(1.00,3.00)]}, 
     35 {(3.00,2.00),[(3.00,2.00)]}, 
     36 {(4.00,3.00),[(4.00,3.00)]}, 
     37 {(7.00,1.00),[(7.00,1.00)]}, 
     38 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
     39 {(3.00,4.00),[(3.00,4.00)]}, 
     40 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
     41 {(5.00,6.00),[(5.00,6.00)]}]
     42 
     43  9 clusters:
     44 [
     45 {(1.00,1.00),[(1.00,1.00)]}, 
     46 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
     47 {(1.00,3.00),[(1.00,3.00)]}, 
     48 {(7.00,1.00),[(7.00,1.00)]}, 
     49 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
     50 {(3.50,2.50),[(3.00,2.00), (4.00,3.00)]}, 
     51 {(3.00,4.00),[(3.00,4.00)]}, 
     52 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
     53 {(5.00,6.00),[(5.00,6.00)]}]
     54 
     55  8 clusters:
     56 [
     57 {(1.00,1.00),[(1.00,1.00)]}, 
     58 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
     59 {(1.00,3.00),[(1.00,3.00)]}, 
     60 {(3.33,3.00),[(3.00,2.00), (4.00,3.00), (3.00,4.00)]}, 
     61 {(7.00,1.00),[(7.00,1.00)]}, 
     62 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
     63 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
     64 {(5.00,6.00),[(5.00,6.00)]}]
     65 
     66  7 clusters:
     67 [
     68 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
     69 {(3.33,3.00),[(3.00,2.00), (4.00,3.00), (3.00,4.00)]}, 
     70 {(1.00,2.00),[(1.00,1.00), (1.00,3.00)]}, 
     71 {(7.00,1.00),[(7.00,1.00)]}, 
     72 {(7.00,5.50),[(7.00,6.00), (7.00,5.00)]}, 
     73 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}, 
     74 {(5.00,6.00),[(5.00,6.00)]}]
     75 
     76  6 clusters:
     77 [
     78 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
     79 {(3.33,3.00),[(3.00,2.00), (4.00,3.00), (3.00,4.00)]}, 
     80 {(1.00,2.00),[(1.00,1.00), (1.00,3.00)]}, 
     81 {(6.33,5.67),[(7.00,6.00), (7.00,5.00), (5.00,6.00)]}, 
     82 {(7.00,1.00),[(7.00,1.00)]}, 
     83 {(6.00,3.50),[(6.00,3.00), (6.00,4.00)]}]
     84 
     85  5 clusters:
     86 [
     87 {(6.20,4.80),[(6.00,3.00), (7.00,6.00), (7.00,5.00), (6.00,4.00), (5.00,6.00)]}, 
     88 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
     89 {(3.33,3.00),[(3.00,2.00), (4.00,3.00), (3.00,4.00)]}, 
     90 {(1.00,2.00),[(1.00,1.00), (1.00,3.00)]}, 
     91 {(7.00,1.00),[(7.00,1.00)]}]
     92 
     93  4 clusters:
     94 [
     95 {(6.20,4.80),[(6.00,3.00), (7.00,6.00), (7.00,5.00), (6.00,4.00), (5.00,6.00)]}, 
     96 {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}, 
     97 {(7.00,1.00),[(7.00,1.00)]}, 
     98 {(2.40,2.60),[(1.00,1.00), (3.00,2.00), (4.00,3.00), (3.00,4.00), (1.00,3.00)]}]
     99 
    100  3 clusters:
    101 [
    102 {(6.20,4.80),[(6.00,3.00), (7.00,6.00), (7.00,5.00), (6.00,4.00), (5.00,6.00)]}, 
    103 {(7.00,1.00),[(7.00,1.00)]}, 
    104 {(2.14,3.43),[(1.00,1.00), (2.00,6.00), (3.00,2.00), (4.00,3.00), (3.00,4.00), (1.00,3.00), (1.00,5.00)]}]
    View Code

    weka实现

     1 import java.util.ArrayList;
     2 import weka.clusterers.HierarchicalClusterer;
     3 import static weka.clusterers.HierarchicalClusterer.TAGS_LINK_TYPE;
     4 import weka.core.Attribute;
     5 import weka.core.Instance;
     6 import weka.core.Instances;
     7 import weka.core.SelectedTag;
     8 import weka.core.SparseInstance;
     9 
    10 public class WekaHierarchicalClustering {
    11     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
    12         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
    13     private static final int M = DATA.length;  // number of points
    14     private static final int K = 3;            // number of clusters
    15 
    16     public static void main(String[] args) {
    17         Instances dataset = load(DATA);
    18         HierarchicalClusterer hc = new HierarchicalClusterer();
    19         hc.setLinkType(new SelectedTag(4, TAGS_LINK_TYPE));  // CENTROID
    20         hc.setNumClusters(3);
    21         try {
    22             hc.buildClusterer(dataset);
    23             for (Instance instance : dataset) {
    24                 System.out.printf("(%.0f,%.0f): %s%n", 
    25                         instance.value(0), instance.value(1), 
    26                         hc.clusterInstance(instance));
    27             }
    28         } catch (Exception e) {
    29             System.err.println(e);
    30         }
    31     }
    32     
    33     private static Instances load(double[][] data) {
    34         ArrayList<Attribute> attributes = new ArrayList<Attribute>();
    35         attributes.add(new Attribute("X"));
    36         attributes.add(new Attribute("Y"));
    37         Instances dataset = new Instances("Dataset", attributes, M);
    38         for (double[] datum : data) {
    39             Instance instance = new SparseInstance(2);
    40             instance.setValue(0, datum[0]);
    41             instance.setValue(1, datum[1]);
    42             dataset.add(instance);
    43         }
    44         return dataset;
    45     }
    46 }
    View Code

    结果-->

    (1,1): 0
    (1,3): 0
    (1,5): 0
    (2,6): 0
    (3,2): 0
    (3,4): 0
    (4,3): 0
    (5,6): 1
    (6,3): 1
    (6,4): 1
    (7,1): 2
    (7,5): 1
    (7,6): 1
    View Code

    weka画图

     1 import java.awt.BorderLayout;
     2 import java.awt.Container;
     3 import java.util.ArrayList;
     4 import javax.swing.JFrame;
     5 import weka.clusterers.HierarchicalClusterer;
     6 import static weka.clusterers.HierarchicalClusterer.TAGS_LINK_TYPE;
     7 import weka.core.Attribute;
     8 import weka.core.Instance;
     9 import weka.core.Instances;
    10 import weka.core.SelectedTag;
    11 import weka.core.SparseInstance;
    12 import weka.gui.hierarchyvisualizer.HierarchyVisualizer;
    13 
    14 public class WekaHierarchicalClustering2 {
    15     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
    16         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
    17     private static final int M = DATA.length;  // number of points
    18     private static final int K = 3;            // number of clusters
    19 
    20     public static void main(String[] args) {
    21         Instances dataset = load(DATA);
    22         HierarchicalClusterer hc = new HierarchicalClusterer();
    23         hc.setLinkType(new SelectedTag(4, TAGS_LINK_TYPE));  // CENTROID
    24         hc.setNumClusters(1);
    25         try {
    26             hc.buildClusterer(dataset);
    27             for (Instance instance : dataset) {
    28                 System.out.printf("(%.0f,%.0f): %s%n", 
    29                         instance.value(0), instance.value(1), 
    30                         hc.clusterInstance(instance));
    31             }
    32             displayDendrogram(hc.graph());
    33         } catch (Exception e) {
    34             System.err.println(e);
    35         }
    36     }
    37     
    38     private static Instances load(double[][] data) {
    39         ArrayList<Attribute> attributes = new ArrayList<Attribute>();
    40         attributes.add(new Attribute("X"));
    41         attributes.add(new Attribute("Y"));
    42         Instances dataset = new Instances("Dataset", attributes, M);
    43         for (double[] datum : data) {
    44             Instance instance = new SparseInstance(2);
    45             instance.setValue(0, datum[0]);
    46             instance.setValue(1, datum[1]);
    47             dataset.add(instance);
    48         }
    49         return dataset;
    50     }
    51     
    52     public static void displayDendrogram(String graph) {
    53         JFrame frame = new JFrame("Dendrogram");
    54         frame.setSize(500, 400);
    55         frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
    56         Container pane = frame.getContentPane();
    57         pane.setLayout(new BorderLayout());
    58         pane.add(new HierarchyVisualizer(graph));
    59         frame.setVisible(true);
    60     }
    61 }
    View Code

     

    K-均值聚类

    • 需求:同上
    • 实现:从数据集中选k个点创建k个聚类,其余点添加到最近的聚类中,重新计算中心

    KMeans.java(普通实现)

     1 import java.util.HashSet;
     2 import java.util.Random;
     3 import java.util.Set;
     4 
     5 public class KMeans {
     6     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2},
     7             {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
     8     private static final int M = DATA.length;
     9     private static final int K = 3;
    10     private static HashSet<Point> points;
    11     private static HashSet<Cluster> clusters = new HashSet();
    12     private static Random RANDOM = new Random();
    13 
    14     public static void main(String[] args){
    15         points = load(DATA);
    16 
    17         int i0 = RANDOM.nextInt(M);
    18         Point p = new Point(DATA[i0][0],DATA[i0][1]);
    19         points.remove(p);
    20 
    21         HashSet<Point> initSet = new HashSet();
    22         initSet.add(p);
    23 
    24         for(int i = 1; i < K; i ++){
    25             p = farthestFrom(initSet);
    26             initSet.add(p);
    27             points.remove(p);
    28         }
    29 
    30         for(Point point:initSet){
    31             Cluster cluster = new Cluster(point);
    32             clusters.add(cluster);
    33         }
    34 
    35         for(Point point:points){
    36             Cluster cluster = closestTo(point);
    37             cluster.add(point);
    38             cluster.recomputeCentroid();
    39         }
    40         System.out.println(clusters);
    41     }
    42 
    43     private static HashSet<Point> load(double[][] data) {
    44         HashSet<Point> points = new HashSet();
    45         for (double[] datum : DATA) {
    46             points.add(new Point(datum[0], datum[1]));
    47         }
    48         return points;
    49     }
    50 
    51     // return the cluster whose centroid is closet to the specified point
    52     private static Cluster closestTo(Point point){
    53         double minDist = Double.POSITIVE_INFINITY;
    54         Cluster c = null;
    55         for(Cluster cluster:clusters){
    56             double d = distance2(cluster.getCentroid(),point);
    57             if(d < minDist){
    58                 minDist = d;
    59                 c = cluster;
    60             }
    61         }
    62         return c;
    63     }
    64 
    65     // return the point that is farthest from the specified set
    66     private static Point farthestFrom(Set<Point> set){
    67         Point p = null;
    68         double maxDist = 0.0;
    69         for(Point point:points){
    70             if(set.contains(point)){
    71                 continue;
    72             }
    73             double d = dist(point,set);
    74             if(d > maxDist){
    75                 p = point;
    76                 maxDist = d;
    77             }
    78         }
    79         return p;
    80     }
    81 
    82     // return the distance from p to the nearest point in the set
    83     public static double dist(Point p, Set<Point> set){
    84         double minDist = Double.POSITIVE_INFINITY;
    85         for(Point point:set){
    86             double d = distance2(p,point);
    87             minDist = (d < minDist ? d : minDist);
    88         }
    89         return minDist;
    90     }
    91 
    92     public static double distance2(Point p, Point q){
    93         double dx = p.getX() - q.getX();
    94         double dy = p.getY() - q.getY();
    95         return dx*dx + dy*dy;
    96     }
    97 }
    View Code

    [{(2.40,2.60),[(1.00,1.00), (1.00,3.00), (3.00,2.00), (4.00,3.00), (3.00,4.00)]},
    {(6.33,4.17),[(6.00,3.00), (7.00,6.00), (6.00,4.00), (7.00,5.00), (7.00,1.00), (5.00,6.00)]},
    {(1.50,5.50),[(2.00,6.00), (1.00,5.00)]}]

    KMeans.java(Weka 实现)

     1 import java.util.ArrayList;
     2 import weka.clusterers.SimpleKMeans;
     3 import weka.core.Attribute;
     4 import weka.core.Instance;
     5 import weka.core.Instances;
     6 import weka.core.SparseInstance;
     7 
     8 public class KMeans {
     9     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
    10         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
    11     private static final int M = DATA.length;  // number of points
    12     private static final int K = 3;            // number of clusters
    13 
    14     public static void main(String[] args) {
    15         Instances dataset = load(DATA);
    16         SimpleKMeans skm = new SimpleKMeans();
    17         System.out.printf("%d clusters:%n", K);
    18         try {
    19             skm.setNumClusters(K);
    20             skm.buildClusterer(dataset);
    21             for (Instance instance : dataset) {
    22                 System.out.printf("(%.0f,%.0f): %s%n", 
    23                         instance.value(0), instance.value(1), 
    24                         skm.clusterInstance(instance));
    25             }
    26         } catch (Exception e) {
    27             System.err.println(e);
    28         }
    29     }
    30     
    31     private static Instances load(double[][] data) {
    32         ArrayList<Attribute> attributes = new ArrayList<Attribute>();
    33         attributes.add(new Attribute("X"));
    34         attributes.add(new Attribute("Y"));
    35         Instances dataset = new Instances("Dataset", attributes, M);
    36         for (double[] datum : data) {
    37             Instance instance = new SparseInstance(2);
    38             instance.setValue(0, datum[0]);
    39             instance.setValue(1, datum[1]);
    40             dataset.add(instance);
    41         }
    42         return dataset;
    43     }
    44 }
    View Code

    结果-->

    (1,1): 1
    (1,3): 1
    (1,5): 0
    (2,6): 0
    (3,2): 1
    (3,4): 0
    (4,3): 0
    (5,6): 0
    (6,3): 2
    (6,4): 2
    (7,1): 2
    (7,5): 2
    (7,6): 2
    View Code

    KMeansPlusPlus.java(Apache Commons Math 实现)

     1 import java.util.ArrayList;
     2 import java.util.List;
     3 import org.apache.commons.math3.ml.clustering.CentroidCluster;
     4 import org.apache.commons.math3.ml.clustering.DoublePoint;
     5 import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
     6 import org.apache.commons.math3.ml.distance.EuclideanDistance;
     7 
     8 public class KMeansPlusPlus {
     9     private static final double[][] DATA = {{1,1}, {1,3}, {1,5}, {2,6}, {3,2}, 
    10         {3,4}, {4,3}, {5,6}, {6,3}, {6,4}, {7,1}, {7,5}, {7,6}};
    11     private static final int M = DATA.length;  // number of points
    12     private static final int K = 3;  // number of clusters
    13     private static final int MAX = 100;  // maximum number of iterations
    14     private static final EuclideanDistance ED = new EuclideanDistance();
    15     
    16     public static void main(String[] args) {
    17         List<DoublePoint> points = load(DATA);
    18         KMeansPlusPlusClusterer<DoublePoint> clusterer;
    19         clusterer = new KMeansPlusPlusClusterer(K, MAX, ED);
    20         List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
    21         
    22         for (CentroidCluster<DoublePoint> cluster : clusters) {
    23             System.out.println(cluster.getPoints());
    24         }
    25     }
    26     
    27     private static List<DoublePoint> load(double[][] data) {
    28         List<DoublePoint> points = new ArrayList(M);
    29         for (double[] pair : data) {
    30             points.add(new DoublePoint(pair));            
    31         }
    32         return points;
    33     } 
    34 }
    View Code

    [[5.0, 6.0], [6.0, 3.0], [6.0, 4.0], [7.0, 5.0], [7.0, 6.0]]
    [[1.0, 1.0], [1.0, 3.0], [1.0, 5.0], [2.0, 6.0], [3.0, 2.0], [3.0, 4.0], [4.0, 3.0]]
    [[7.0, 1.0]]

    仿射传播聚类

    • 需求:同上
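    • 实现:以点间负的平方欧氏距离构造相似度矩阵s,迭代更新吸引度(responsibility)矩阵r和归属度(availability)矩阵a,最后对每个点i取使a[i][k]+r[i][k]最大的k作为其代表点(exemplar)
    • 特点:不同于KMeans,聚类个数k不需事先确定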
     /**
      * Affinity propagation over five fixed 2-D points.
      *
      * <p>Similarities are negative squared Euclidean distances, with the diagonal
      * set to the average off-diagonal similarity. Responsibilities and
      * availabilities are updated ITERATIONS times with damping factor DAMPER, and
      * each point is then reported with the index k maximizing a[i][k] + r[i][k].
      */
     public class AffinityPropagation {
         private static double[][] x = {{1,2}, {2,3}, {4,1}, {4,4}, {5,3}};
         private static int n = x.length;                 // number of points
         private static double[][] s = new double[n][n];  // similarities
         private static double[][] r = new double[n][n];  // responsibilities
         private static double[][] a = new double[n][n];  // availabilities
         private static final int ITERATIONS = 10;
         private static final double DAMPER = 0.5;

         public static void main(String[] args) {
             initSimilarities();
             int round = 0;
             while (round < ITERATIONS) {
                 updateResponsibilities();
                 updateAvailabilities();
                 round++;
             }
             printResults();
         }

         /** Fills s symmetrically and puts the mean off-diagonal value on the diagonal. */
         private static void initSimilarities() {
             double total = 0;
             for (int row = 1; row < n; row++) {
                 for (int col = 0; col < row; col++) {
                     double similarity = negSqEuclidDist(x[row], x[col]);
                     s[col][row] = similarity;
                     s[row][col] = similarity;
                     total += similarity;
                 }
             }
             // There are (n*n - n)/2 pairs with col < row, hence the factor 2.
             double mean = 2*total/(n*n - n);
             for (int d = 0; d < n; d++) {
                 s[d][d] = mean;
             }
         }

         /** Damped update of r[i][k] from s[i][k] and the best competing a + s. */
         private static void updateResponsibilities() {
             for (int i = 0; i < n; i++) {
                 for (int k = 0; k < n; k++) {
                     double rival = Double.NEGATIVE_INFINITY;
                     for (int j = 0; j < n; j++) {
                         if (j == k) {
                             continue;
                         }
                         rival = Math.max(rival, a[i][j] + s[i][j]);
                     }
                     double target = s[i][k] - rival;
                     r[i][k] = DAMPER*r[i][k] + (1 - DAMPER)*target;
                 }
             }
         }

         /** Damped update of a[i][k] from the positive responsibilities toward k. */
         private static void updateAvailabilities() {
             for (int i = 0; i < n; i++) {
                 for (int k = 0; k < n; k++) {
                     double target;
                     if (k == i) {
                         target = sumOfPos(k,k);
                     } else {
                         target = Math.min(0, r[k][k] + sumOfPos(i,k));
                     }
                     a[i][k] = DAMPER*a[i][k] + (1 - DAMPER)*target;
                 }
             }
         }

         /*  Returns the negative square of the Euclidean distance from x to y.
         */
         private static double negSqEuclidDist(double[] x, double[] y) {
             double dx = x[0] - y[0];
             double dy = x[1] - y[1];
             return -(dx*dx + dy*dy);
         }

         /*  Returns the sum of the positive r[j][k] excluding r[i][k] and r[k][k].
         */
         private static double sumOfPos(int i, int k) {
             double total = 0;
             for (int j = 0; j < n; j++) {
                 if (j == i || j == k) {
                     continue;
                 }
                 total += Math.max(0, r[j][k]);
             }
             return total;
         }

         /** Prints, for each point, the index with the largest a + r (first one on ties). */
         private static void printResults() {
             for (int i = 0; i < n; i++) {
                 int exemplar = 0;
                 double best = a[i][0] + r[i][0];
                 for (int j = 1; j < n; j++) {
                     double score = a[i][j] + r[i][j];
                     if (score > best) {
                         best = score;
                         exemplar = j;
                     }
                 }
                 System.out.printf("point %d has exemplar point %d%n", i, exemplar);
             }
         }
     }
    View Code

    point 0 has exemplar point 1
    point 1 has exemplar point 1
    point 2 has exemplar point 4
    point 3 has exemplar point 4
    point 4 has exemplar point 4

    参考

    https://blog.csdn.net/xzfreewind/article/details/73770327

  • 相关阅读:
    c++标准库容器【转】
    C++命名空间的解释 【转】
    [转载]定义、公理、定理、推论、命题和引理的区别
    待读论文
    矩阵分解 Matrix Factorization (RegularSVD) 实验总结
    Predicting the Next Location: A Recurrent Model with Spatial and Temporal Contexts AAAI2016
    Discovering Urban Functional Zones Using Latent Activity Trajectories TKDE 2015
    numpy
    python 编程 规范
    深度学习
  • 原文地址:https://www.cnblogs.com/cxc1357/p/14692230.html
Copyright © 2011-2022 走看看