zoukankan      html  css  js  c++  java
  • 3.聚类–K-means的Java实现

    K-means的步骤

    输入: 含 n 个样本的数据集,簇的数目 K

    输出: K 个簇

    算法步骤:

    1.初始化K个簇类中心C1,C2,……,Ck (通常随机选择)

    2.repeat 步骤3,4

    3.将数据集中的每个样本分配到与之最近的中心Ci所在的簇;

    4. 更新聚类中心Ci,即计算各个簇的样本均值;

    5.直到样本分配不再改变

    上代码:

    import java.lang.annotation.ElementType;
    import java.lang.annotation.Retention;
    import java.lang.annotation.RetentionPolicy;
    import java.lang.annotation.Target;
    
    /**
     * Marker annotation for fields that take part in the K-means algorithm.
     *
     * <p>Place it on a field of the sample class to include that field as a
     * clustering dimension. Only numeric fields are supported. The annotation is
     * retained at runtime so that it can be discovered through reflection.
     *
     * @author 阿飞哥
     */
    @Target(ElementType.FIELD)
    @Retention(RetentionPolicy.RUNTIME)
    public @interface KmeanField {
    }
    
    

    import java.lang.annotation.Annotation;
    import java.lang.reflect.Field;
    import java.lang.reflect.Method;
    import java.util.ArrayList;
    import java.util.List;
    
    /**
     * K-means clustering over the numeric bean properties of {@code T}.
     *
     * <p>Every field declared on the sample class and annotated with
     * {@link KmeanField} is one clustering dimension. A field {@code foo} is read
     * reflectively through a no-argument {@code getFoo()} and written on centroid
     * objects through {@code setFoo(double)}; both are looked up among the methods
     * declared directly on the object's runtime class. Centroid objects are
     * created with the sample class's no-argument constructor.
     *
     * @param <T> the sample type
     * @author 阿飞哥
     */
    public class Kmeans<T> {

        /** All samples to be clustered. */
        private List<T> players = new ArrayList<T>();

        /** Runtime class of the samples; used to instantiate centroid objects. */
        private Class<T> classT;

        /** Current cluster centers, one per cluster. */
        private List<T> initPlayers = new ArrayList<T>();

        /** Names of the fields annotated with {@link KmeanField}. */
        private List<String> fieldNames = new ArrayList<String>();

        /** Number of clusters. */
        private int k = 1;

        /**
         * Creates an instance without samples; {@link #comput()} on it returns a
         * single empty cluster.
         */
        public Kmeans() {

        }

        /**
         * Prepares a clustering of {@code list} into {@code k} clusters.
         *
         * <p>The clustering dimensions are discovered from the runtime class of the
         * first sample, and the first {@code k} samples seed the cluster centers.
         *
         * @param list the samples; must be non-empty and its first element non-null
         * @param k the number of clusters; must be between 1 and {@code list.size()}
         * @throws IllegalArgumentException if {@code list} is null or empty, its
         *         first element is null, or {@code k} is out of range
         */
        public Kmeans(List<T> list, int k) {
            if (list == null || list.isEmpty()) {
                throw new IllegalArgumentException("list must contain at least one sample");
            }
            if (k < 1 || k > list.size()) {
                throw new IllegalArgumentException(
                        "k must be between 1 and " + list.size() + " but was " + k);
            }
            T first = list.get(0);
            if (first == null) {
                throw new IllegalArgumentException("the first sample must not be null");
            }

            this.players = list;
            this.k = k;

            // getClass() on a T yields Class<? extends T>; it is only used to build more Ts.
            @SuppressWarnings("unchecked")
            Class<T> sampleClass = (Class<T>) first.getClass();
            this.classT = sampleClass;

            for (Field field : sampleClass.getDeclaredFields()) {
                if (field.isAnnotationPresent(KmeanField.class)) {
                    fieldNames.add(field.getName());
                }
            }

            // Seed the centers with the first k samples.
            for (int i = 0; i < k; i++) {
                initPlayers.add(players.get(i));
            }
        }

        /**
         * Runs the clustering until no cluster center changes any more.
         *
         * <p>Each pass assigns every sample to the nearest current center (ties go
         * to the lowest cluster index) and then replaces each center by the mean
         * of its cluster. A cluster that received no samples, or whose mean object
         * could not be built, keeps its previous center.
         *
         * @return an array of {@code k} lists, one per cluster; all lists are empty
         *         when there are no samples
         */
        public List<T>[] comput() {
            List<T>[] results = emptyClusters();
            if (players.isEmpty()) {
                return results;
            }

            boolean centerchange = true;
            while (centerchange) {
                centerchange = false;
                results = emptyClusters();

                // Assignment step: put every sample into the cluster of its nearest center.
                for (T sample : players) {
                    double[] dists = new double[k];
                    for (int j = 0; j < k; j++) {
                        dists[j] = distance(initPlayers.get(j), sample);
                    }
                    results[computOrder(dists)].add(sample);
                }

                // Update step: move each center to the mean of its cluster.
                for (int i = 0; i < k; i++) {
                    if (results[i].isEmpty()) {
                        // No samples to average; moving the center to an all-default
                        // object would be arbitrary, so leave it where it is.
                        continue;
                    }
                    T newCenter = findNewCenter(results[i]);
                    if (newCenter == null) {
                        // findNewCenter already reported the failure; a null center
                        // would break every later distance computation.
                        continue;
                    }
                    if (!IsPlayerEqual(newCenter, initPlayers.get(i))) {
                        centerchange = true;
                        initPlayers.set(i, newCenter);
                    }
                }
            }
            return results;
        }

        /**
         * Tells whether two objects agree on every clustered property.
         *
         * @param p1 first object, may be null
         * @param p2 second object, may be null
         * @return {@code true} if both are the same reference (including both
         *         null) or all clustered property values are equal, where two
         *         null values count as equal; {@code false} otherwise
         */
        public boolean IsPlayerEqual(T p1, T p2) {
            if (p1 == p2) {
                return true;
            }
            if (p1 == null || p2 == null) {
                return false;
            }

            try {
                for (String fieldName : fieldNames) {
                    String getName = accessorName("get", fieldName);
                    Object value1 = invokeMethod(p1, getName, null);
                    Object value2 = invokeMethod(p2, getName, null);
                    // Null-safe: two unreadable/null values must not count as a
                    // difference, otherwise comput() could never converge.
                    boolean same = (value1 == null) ? (value2 == null) : value1.equals(value2);
                    if (!same) {
                        return false;
                    }
                }
                return true;
            } catch (RuntimeException e) {
                e.printStackTrace();
                return false;
            }
        }

        /**
         * Builds the center of a cluster: a new {@code T} whose clustered
         * properties hold the mean of the corresponding values in {@code ps}.
         * Null property values count as 0.
         *
         * @param ps the samples of one cluster
         * @return the new center; a freshly constructed {@code T} if {@code ps} is
         *         null or empty; {@code null} if the center could not be built
         *         (the cause is printed to standard error)
         */
        public T findNewCenter(List<T> ps) {
            try {
                T center = classT.getDeclaredConstructor().newInstance();
                if (ps == null || ps.isEmpty()) {
                    return center;
                }

                int dimensions = fieldNames.size();
                double[] sums = new double[dimensions];
                for (T sample : ps) {
                    for (int i = 0; i < dimensions; i++) {
                        String getName = accessorName("get", fieldNames.get(i));
                        sums[i] += toDouble(invokeMethod(sample, getName, null));
                    }
                }

                for (int i = 0; i < dimensions; i++) {
                    double mean = sums[i] / ps.size();
                    String setName = accessorName("set", fieldNames.get(i));
                    invokeMethod(center, setName, new Class<?>[] {double.class}, mean);
                }
                return center;
            } catch (Exception ex) {
                ex.printStackTrace();
                return null;
            }
        }

        /**
         * Finds the position of the smallest distance.
         *
         * @param dists the distances to compare
         * @return the index of the first smallest element, or 0 for an empty array
         */
        public int computOrder(double[] dists) {
            int index = 0;
            for (int i = 1; i < dists.length; i++) {
                if (dists[i] < dists[index]) {
                    index = i;
                }
            }
            return index;
        }

        /**
         * Computes the Euclidean distance between two objects over the clustered
         * properties. Null property values count as 0.
         *
         * <p>If reading or converting a value fails, the failure is printed to
         * standard error and the distance accumulated so far is returned.
         *
         * @param p0 first object
         * @param p1 second object
         * @return the distance between {@code p0} and {@code p1}
         */
        public double distance(T p0, T p1) {
            double dis = 0;
            try {
                for (String fieldName : fieldNames) {
                    String getName = accessorName("get", fieldName);
                    double field0Value = toDouble(invokeMethod(p0, getName, null));
                    double field1Value = toDouble(invokeMethod(p1, getName, null));
                    dis += Math.pow(field0Value - field1Value, 2);
                }
            } catch (Exception ex) {
                ex.printStackTrace();
            }
            return Math.sqrt(dis);
        }

        /*------ shared helpers -----*/

        /**
         * Reflectively invokes a method declared on the runtime class of
         * {@code owner}.
         *
         * @param owner the object to invoke the method on; must not be null
         * @param methodName the name of the method
         * @param argsClass the parameter types of the method; {@code null} means
         *        no parameters
         * @param args the arguments to pass
         * @return the method's result, or {@code null} if the method could not be
         *         found or invoked (the cause is printed to standard error)
         */
        public Object invokeMethod(Object owner, String methodName, Class<?>[] argsClass,
                Object... args) {
            Class<?> ownerClass = owner.getClass();

            try {
                Method method = ownerClass.getDeclaredMethod(methodName, argsClass);
                return method.invoke(owner, args);
            } catch (Exception ex) {
                ex.printStackTrace();
            }

            return null;
        }

        /** Creates {@code k} fresh, empty clusters. */
        @SuppressWarnings("unchecked") // generic arrays cannot be created directly
        private List<T>[] emptyClusters() {
            List<T>[] clusters = new ArrayList[k];
            for (int i = 0; i < k; i++) {
                clusters[i] = new ArrayList<T>();
            }
            return clusters;
        }

        /**
         * Builds an accessor name such as {@code getGoal} from a prefix and a field
         * name. The first letter is upper-cased per character so the result does
         * not depend on the default locale.
         */
        private static String accessorName(String prefix, String fieldName) {
            return prefix + Character.toUpperCase(fieldName.charAt(0)) + fieldName.substring(1);
        }

        /**
         * Converts a property value to a double: null becomes 0, a {@code Double}
         * is used as is, anything else is parsed from its string form.
         */
        private static double toDouble(Object value) {
            if (value == null) {
                return 0;
            }
            if (value instanceof Double) {
                return ((Double) value).doubleValue();
            }
            return Double.parseDouble(value.toString());
        }

    }
    
    

    /**
     * Statistics of a single player, used as sample data for clustering.
     *
     * <p>Only {@code goal} carries the {@code KmeanField} annotation; the other
     * fields are not annotated.
     */
    public class Player {

        /** Identifier of the player. */
        private int id;

        /** Name of the player; this is also what {@link #toString()} returns. */
        private String name;

        /** Age of the player. */
        private int age;

        /** Points scored. */
        @KmeanField
        private double goal;

        /** Assists. */
        private double assists;

        /** Rebounds. */
        private double backboard;

        /** Steals. */
        private double steals;

        /** @return the identifier */
        public int getId() {
            return this.id;
        }

        /** @param id the identifier to set */
        public void setId(int id) {
            this.id = id;
        }

        /** @return the name, possibly null */
        public String getName() {
            return this.name;
        }

        /** @param name the name to set */
        public void setName(String name) {
            this.name = name;
        }

        /** @return the age */
        public int getAge() {
            return this.age;
        }

        /** @param age the age to set */
        public void setAge(int age) {
            this.age = age;
        }

        /** @return the points scored */
        public double getGoal() {
            return this.goal;
        }

        /** @param goal the points scored to set */
        public void setGoal(double goal) {
            this.goal = goal;
        }

        /** @return the assists */
        public double getAssists() {
            return this.assists;
        }

        /** @param assists the assists to set */
        public void setAssists(double assists) {
            this.assists = assists;
        }

        /** @return the rebounds */
        public double getBackboard() {
            return this.backboard;
        }

        /** @param backboard the rebounds to set */
        public void setBackboard(double backboard) {
            this.backboard = backboard;
        }

        /** @return the steals */
        public double getSteals() {
            return this.steals;
        }

        /** @param steals the steals to set */
        public void setSteals(double steals) {
            this.steals = steals;
        }

        /** Returns the player's name unchanged (null if no name was set). */
        @Override
        public String toString() {
            return this.name;
        }
    }
    
    

     
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Random;
    
    /**
     * Demo program: clusters nine sample players into two groups and prints the
     * members of each group.
     */
    public class TestMain {

        /**
         * Builds the sample players, runs the clustering with two clusters and
         * prints, per cluster, each player's name, goal, assists, steals and
         * backboard values.
         *
         * @param args ignored
         */
        public static void main(String[] args) {
            List<Player> listPlayers = new ArrayList<Player>();

            // The first player is the only one that also has assists set.
            Player first = newPlayer("afei1", 1);
            first.setAssists(8);
            listPlayers.add(first);

            // Remaining players afei2..afei9 differ only in their goal value.
            double[] goals = {2, 3, 7, 8, 25, 26, 27, 28};
            for (int i = 0; i < goals.length; i++) {
                listPlayers.add(newPlayer("afei" + (i + 2), goals[i]));
            }

            Kmeans<Player> kmeans = new Kmeans<Player>(listPlayers, 2);
            List<Player>[] results = kmeans.comput();
            for (int i = 0; i < results.length; i++) {
                System.out.println("===========类别" + (i + 1) + "================");
                for (Player p : results[i]) {
                    System.out.println(p.getName() + "--->"
                            + p.getGoal() + "," + p.getAssists() + ","
                            + p.getSteals() + "," + p.getBackboard());
                }
            }
        }

        /** Creates a player with the given name and goal value; other fields keep their defaults. */
        private static Player newPlayer(String name, double goal) {
            Player player = new Player();
            player.setName(name);
            player.setGoal(goal);
            return player;
        }

    }
    
    

    源码:https://github.com/chaoren399/dkdemo/tree/master/kmeans/src

  • 相关阅读:
    大数据Hadoop第二周——配置新的节点DataNode及ip地址
    vue环境搭建详细步骤
    苹果电脑Mac系统如何下载安装谷歌Chrome浏览器
    点云的基本特征和描述
    ModuleNotFoundError: No module named 'rospkg'
    ROS的多传感器时间同步机制Time Synchronizer
    Spring Cloud 2020 版本重大变革,更好的命名方式!
    Spring MVC 接收请求参数所有方式总结!
    阿里为什么不用 Zookeeper 做服务发现?
    微服务之间最佳调用方式是什么?
  • 原文地址:https://www.cnblogs.com/chaoren399/p/5006563.html
Copyright © 2011-2022 走看看