zoukankan      html  css  js  c++  java
  • K-means算法Java实现

    public class KMeansCluster { 

            private int k;//簇的个数 
            private int num = 100000;//迭代次数 
            private List<double> datas;//原始样本集 
            private String address;//样本集路径 
            private List<point> data = new ArrayList<point>(); 
            private AbstractDistance distance = new AbstractDistance() { 
                @Override 
                public double getDis(Point p1, Point p2) { 
                    //欧几里德距离 
                    return Math.sqrt(Math.pow(p1.getX() - p2.getX(), 2) + Math.pow(p1.getY() - p2.getY(), 2)); 
                
            }; 
           
            public KMeansCluster(int k, int num, String address) { 
                this.k = k; 
                this.num = num; 
                this.address = address; 
            
           
            public KMeansCluster(int k, String address) { 
                this.k = k; 
                this.address = address; 
            
           
            public KMeansCluster(int k, List<double> datas) { 
                this.k = k; 
                this.datas = datas; 
            
           
            public KMeansCluster(int k, int num, List<double> datas) { 
                this.k = k; 
                this.num = num; 
                this.datas = datas; 
            
           
            private void check() { 
                if (k == 0
                    throw new IllegalArgumentException("k must be the number > 0"); 
           
                if (address == null && datas == null
                    throw new IllegalArgumentException("program can't get real data"); 
            
           
            /**
             * 初始化数据
             *
             * @throws java.io.FileNotFoundException
             */ 
            public void init() throws FileNotFoundException { 
                check(); 
                //读取文件,init data 
                //处理原始数据 
                for (int i = 0, j = datas.size(); i < j; i++) 
                    data.add(new Point(i, datas.get(i), 0)); 
            
           
            /**
             * 第一次随机选取中心点
             *
             * @return
             */ 
            public Set<point> chooseCenter() { 
                Set<point> center = new HashSet<point>(); 
                Random ran = new Random(); 
                int roll = 0
                while (center.size() < k) { 
                    roll = ran.nextInt(data.size()); 
                    center.add(data.get(roll)); 
                
                return center; 
            
           
            /**
             * @param center
             * @return
             */ 
            public List<cluster> prepare(Set<point> center) { 
                List<cluster> cluster = new ArrayList<cluster>(); 
                Iterator<point> it = center.iterator(); 
                int id = 0
                while (it.hasNext()) { 
                    Point p = it.next(); 
                    if (p.isBeyond()) { 
                        Cluster c = new Cluster(id++, p); 
                        c.addPoint(p); 
                        cluster.add(c); 
                    else 
                        cluster.add(new Cluster(id++, p)); 
                
                return cluster; 
            
           
            /**
             * 第一次运算,中心点为样本值
             *
             * @param center
             * @param cluster
             * @return
             */ 
            public List<cluster> clustering(Set<point> center, List<cluster> cluster) { 
                Point[] p = center.toArray(new Point[0]); 
                TreeSet<distence> distence = new TreeSet<distence>();//存放距离信息 
                Point source; 
                Point dest; 
                boolean flag = false
                for (int i = 0, n = data.size(); i < n; i++) { 
                    distence.clear(); 
                    for (int j = 0; j < center.size(); j++) { 
                        if (center.contains(data.get(i))) 
                            break
           
                        flag = true
                        // 计算距离 
                        source = data.get(i); 
                        dest = p[j]; 
                        distence.add(new Distence(source, dest, distance)); 
                    
                    if (flag == true) { 
                        Distence min = distence.first(); 
                        for (int m = 0, k = cluster.size(); m < k; m++) { 
                            if (cluster.get(m).getCenter().equals(min.getDest())) 
                                cluster.get(m).addPoint(min.getSource()); 
           
                        
                    
                    flag = false
                
           
                return cluster; 
            
           
            /**
             * 迭代运算,中心点为簇内样本均值
             *
             * @param cluster
             * @return
             */ 
            public List<cluster> cluster(List<cluster> cluster) { 
        //        double error; 
                Set<point> lastCenter = new HashSet<point>(); 
                for (int m = 0; m < num; m++) { 
        //            error = 0; 
                    Set<point> center = new HashSet<point>(); 
                    // 重新计算聚类中心 
                    for (int j = 0; j < k; j++) { 
                        List<point> ps = cluster.get(j).getMembers(); 
                        int size = ps.size(); 
                        if (size < 3) { 
                            center.add(cluster.get(j).getCenter()); 
                            continue
                        
                        // 计算距离 
                        double x = 0.0, y = 0.0
                        for (int k1 = 0; k1 < size; k1++) { 
                            x += ps.get(k1).getX(); 
                            y += ps.get(k1).getY(); 
                        
                        //得到新的中心点 
                        Point nc = new Point(-1, x / size, y / size, false); 
                        center.add(nc); 
                    
                    if (lastCenter.containsAll(center))//中心点不在变化,退出迭代 
                        break
                    lastCenter = center; 
                    // 迭代运算 
                    cluster = clustering(center, prepare(center)); 
        //            for (int nz = 0; nz < k; nz++) { 
        //                error += cluster.get(nz).getError();//计算误差 
        //            } 
                
                return cluster; 
            
           
            /**
             * 输出聚类信息到控制台
             *
             * @param cs
             */ 
            public void out2console(List<cluster> cs) { 
                for (int i = 0; i < cs.size(); i++) { 
                    System.out.println("No." + (i + 1) + " cluster:"); 
                    Cluster c = cs.get(i); 
                    List<point> p = c.getMembers(); 
                    for (int j = 0; j < p.size(); j++) { 
                        System.out.println(" " + p.get(j).getX() + " "); 
                    
                    System.out.println(); 
                
            
        }
  • 相关阅读:
    【Python入门自学笔记专辑】——面向对象编程-实例方法11.3.6
    最简单的轮播广告(原生JS)
    (转)JavaScript一:为什么学习JavaScript?
    JAVASCRIPT中经典面试题
    使用AngularJS实现简单:全选和取消全选功能
    canvas绘制经典折线图(一)
    前端总结
    PHP如何连接MySQL数据库
    PHP预定义变量
    PHP语法
  • 原文地址:https://www.cnblogs.com/wcLT/p/4604028.html
Copyright © 2011-2022 走看看