zoukankan      html  css  js  c++  java
  • Java 机器学习之K-means

    1、Main.java

    package com.vue.demo.kmeans;
    
    import java.util.ArrayList;
    import java.util.Set;
    
    /**
     * Demo driver: feeds a small hard-coded 3-dimensional data set to
     * {@code KMeansRun} with k = 3 and prints the pass count and every
     * resulting cluster to standard output.
     */
    public class Main {

        /**
         * Builds the sample data, runs the clustering once and prints the outcome.
         *
         * @param args command-line arguments; not read
         */
        public static void main(String[] args) {
            // One row per sample, one column per dimension.
            float[][] samples = {
                { 1, 2, 3 },
                { 3, 3, 3 },
                { 3, 4, 4 },
                { 5, 6, 5 },
                { 8, 9, 6 },
                { 4, 5, 4 },
                { 6, 4, 2 },
                { 3, 9, 7 },
                { 5, 9, 8 },
                { 4, 2, 10 },
                { 1, 9, 12 },
                { 7, 8, 112 },
                { 7, 8, 4 },
            };

            ArrayList<float[]> dataSet = new ArrayList<float[]>();
            for (float[] sample : samples) {
                dataSet.add(sample);
            }

            KMeansRun kRun = new KMeansRun(3, dataSet);
            Set<Cluster> clusterSet = kRun.run();

            System.out.println("单次迭代运行次数:" + kRun.getIterTimes());
            for (Cluster resultCluster : clusterSet) {
                System.out.println(resultCluster);
            }
        }
    }
    View Code

    2、Point.java

    package com.vue.demo.kmeans;
    
    import com.alibaba.fastjson.JSONObject;
    
    /**
     * A sample (or a computed centre) described by a coordinate array, together
     * with the bookkeeping fields the clustering code reads and writes: an id,
     * the id of the cluster it is assigned to, and the distance to that cluster's
     * centre.
     *
     * <p>Equality and hashing are based on the coordinate array only; id,
     * clusterId, dist and clusterPoint do not take part.
     */
    public class Point {
        private float[] localArray;
        private int id;
        private int clusterId;  // id of the cluster centre this point belongs to
        private float dist;     // distance to the centre of the owning cluster
        private Point clusterPoint; // centre point information

        /** @return the centre point stored on this point, possibly {@code null} */
        public Point getClusterPoint() {
            return clusterPoint;
        }

        /** @param clusterPoint the centre point to store on this point */
        public void setClusterPoint(Point clusterPoint) {
            this.clusterPoint = clusterPoint;
        }

        /** @return the coordinate array itself (not a copy) */
        public float[] getLocalArray() {
            return localArray;
        }

        /** @param localArray the coordinate array to use; stored by reference */
        public void setLocalArray(float[] localArray) {
            this.localArray = localArray;
        }

        /** @return the id of the cluster this point is assigned to */
        public int getClusterId() {
            return clusterId;
        }

        /**
         * Creates a point with an explicit id.
         *
         * @param id         identifier of the point
         * @param localArray coordinates; stored by reference
         */
        public Point(int id, float[] localArray) {
            this.id = id;
            this.localArray = localArray;
        }

        /**
         * Creates a point without an explicit id.
         *
         * @param localArray coordinates; stored by reference
         */
        public Point(float[] localArray) {
            // Original note: -1 means "does not belong to any class". Note that this
            // sets id; clusterId keeps its default value of 0.
            this.id = -1;
            this.localArray = localArray;
        }

        /** Same as {@link #getLocalArray()}; kept because existing callers use this spelling. */
        public float[] getlocalArray() {
            return localArray;
        }

        /** @return the id of this point */
        public int getId() {
            return id;
        }

        /** @param clusterId id of the cluster this point is assigned to */
        public void setClusterId(int clusterId) {
            this.clusterId = clusterId;
        }

        /** Same as {@link #getClusterId()}; kept because existing callers use this spelling. */
        public int getClusterid() {
            return clusterId;
        }

        /** @return the stored distance to the owning cluster's centre */
        public float getDist() {
            return dist;
        }

        /** @param dist distance to the owning cluster's centre */
        public void setDist(float dist) {
            this.dist = dist;
        }

        /** Delegates to {@code JSONObject.toJSONString(this)}. */
        @Override
        public String toString() {
            return JSONObject.toJSONString(this);
        }

        /** @param id the id of this point */
        public void setId(int id) {
            this.id = id;
        }

        /**
         * Two points are equal when they are of the same class and their coordinate
         * arrays have the same length and pairwise equal elements (compared the way
         * {@link Float#equals(Object)} compares, so NaN equals NaN and 0.0f differs
         * from -0.0f). Two points whose arrays are both {@code null} are equal.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Point other = (Point) obj;
            // Null-safe, unlike a hand-written index loop over localArray.
            return java.util.Arrays.equals(localArray, other.localArray);
        }

        /**
         * Hash over every coordinate, consistent with {@link #equals(Object)}.
         * Safe for empty and {@code null} coordinate arrays.
         */
        @Override
        public int hashCode() {
            return java.util.Arrays.hashCode(localArray);
        }
    }
    View Code

    3、KMeansRun.java

    package com.vue.demo.kmeans;
    
    import java.util.*;
    
    /**
     * Runs k-means clustering over a list of equally sized {@code float[]} samples.
     *
     * <p>Usage: construct with k and the data, call {@link #run()}, then read the
     * returned clusters and {@link #getIterTimes()}.
     */
    public class KMeansRun {
        private final int kNum;                       // number of clusters
        private int iterNum = 10;                     // iteration count; not read anywhere in this class

        private int iterMaxTimes = 100000;            // upper bound on assign/update passes in one run()
        private int iterRunTimes = 0;                 // passes executed by the most recent run()
        private float disDiff = (float) 0.01;         // a run stops once no centre moves farther than this

        private final List<float[]> original_data;    // the raw input samples
        private List<Point> pointList = null;         // the input samples wrapped as Point objects (per instance)
        private final DistanceCompute disC = new DistanceCompute();
        private final int len;                        // dimension of every sample

        /**
         * @param k             number of clusters; must be positive and not larger
         *                      than the number of samples
         * @param original_data samples; must be non-null, non-empty, and every row
         *                      must be non-null with the same length
         * @throws IllegalArgumentException if any of the above is violated
         */
        public KMeansRun(int k, List<float[]> original_data) {
            this.kNum = k;
            this.original_data = original_data;
            // Validate before touching the data so bad input yields a clear error.
            check();
            this.len = original_data.get(0).length;
            // Wrap the raw arrays as points.
            init();
        }

        /**
         * Validates the constructor arguments.
         *
         * @throws IllegalArgumentException on a non-positive k, missing data,
         *                                  k larger than the sample count, or rows
         *                                  of differing length
         */
        private void check() {
            if (kNum <= 0) {
                throw new IllegalArgumentException("k must be the number > 0");
            }
            if (original_data == null || original_data.isEmpty()) {
                throw new IllegalArgumentException("program can't get real data");
            }
            if (kNum > original_data.size()) {
                throw new IllegalArgumentException(
                        "k (" + kNum + ") must not exceed the number of data points ("
                                + original_data.size() + ")");
            }
            float[] first = original_data.get(0);
            for (float[] row : original_data) {
                if (row == null || first == null || row.length != first.length) {
                    throw new IllegalArgumentException("every data point must be non-null and have the same length");
                }
            }
        }

        /**
         * Converts every raw array into a {@code Point} whose id is its index in
         * the input list.
         */
        private void init() {
            pointList = new ArrayList<Point>(original_data.size());
            for (int i = 0, j = original_data.size(); i < j; i++) {
                pointList.add(new Point(i, original_data.get(i)));
            }
        }

        /**
         * Picks k random samples with pairwise different coordinates and wraps each
         * one as the centre of a new cluster with ids 0..k-1.
         *
         * @return the initial clusters, iterated in id order
         * @throws IllegalArgumentException if the data holds fewer than k distinct
         *                                  samples
         */
        private Set<Cluster> chooseCenterCluster() {
            // LinkedHashSet keeps clusters in id order, which makes tie-breaking in
            // cluster() and the printed output reproducible for a given start.
            Set<Cluster> clusterSet = new LinkedHashSet<Cluster>();

            // Scanning a shuffled copy visits every sample at most once, so the
            // selection always terminates even when duplicates are present.
            List<Point> candidates = new ArrayList<Point>(pointList);
            Collections.shuffle(candidates, new Random());

            int id = 0;
            for (Point candidate : candidates) {
                if (id == kNum) {
                    break;
                }
                boolean alreadyChosen = false;
                for (Cluster cluster : clusterSet) {
                    if (cluster.getCenter().equals(candidate)) {
                        alreadyChosen = true;
                        break;
                    }
                }
                if (!alreadyChosen) {
                    clusterSet.add(new Cluster(id, candidate));
                    id++;
                }
            }
            if (id < kNum) {
                throw new IllegalArgumentException(
                        "data contains only " + id + " distinct points, cannot form " + kNum + " clusters");
            }
            return clusterSet;
        }

        /**
         * Assigns every point to its nearest centre and rebuilds each cluster's
         * member list accordingly.
         *
         * @param clusterSet clusters whose centres are used and whose member lists
         *                   are cleared and refilled
         */
        public void cluster(Set<Cluster> clusterSet) {
            // Label each point with the id of the closest centre and that distance.
            for (Point point : pointList) {
                float minDis = Float.POSITIVE_INFINITY;
                for (Cluster cluster : clusterSet) {
                    float dis = (float) disC.getEuclideanDis(point, cluster.getCenter());
                    // Strict '<' keeps the first centre on ties and ignores NaN distances.
                    if (dis < minDis) {
                        minDis = dis;
                        point.setClusterId(cluster.getId());
                        point.setDist(minDis);
                    }
                }
            }
            // Drop the previous membership, then add each point to the cluster it was labelled with.
            for (Cluster cluster : clusterSet) {
                cluster.getMembers().clear();
                for (Point point : pointList) {
                    if (point.getClusterid() == cluster.getId()) {
                        cluster.addPoint(point);
                    }
                }
            }
        }

        /**
         * Moves every cluster centre to the per-dimension mean of its members.
         * A cluster without members keeps its current centre.
         *
         * @param clusterSet clusters to update
         * @return {@code true} if at least one centre moved farther than the
         *         convergence threshold, i.e. another pass is needed
         */
        public boolean calculateCenter(Set<Cluster> clusterSet) {
            boolean ifNeedIter = false;
            for (Cluster cluster : clusterSet) {
                List<Point> members = cluster.getMembers();
                if (members.isEmpty()) {
                    // Averaging zero members would give 0/0 = NaN coordinates.
                    continue;
                }
                // Sum the members dimension by dimension ...
                float[] mean = new float[len];
                for (Point member : members) {
                    float[] coords = member.getlocalArray();
                    for (int i = 0; i < len; i++) {
                        mean[i] += coords[i];
                    }
                }
                // ... and divide by the member count.
                for (int i = 0; i < len; i++) {
                    mean[i] = mean[i] / members.size();
                }
                Point newCenter = new Point(mean);
                // Keep iterating as long as any centre still moves noticeably.
                if (disC.getEuclideanDis(cluster.getCenter(), newCenter) > disDiff) {
                    ifNeedIter = true;
                }
                cluster.setCenter(newCenter);
            }
            return ifNeedIter;
        }

        /**
         * Runs k-means from a fresh random start: alternates assignment and centre
         * update until no centre moves farther than the threshold or the pass limit
         * is reached.
         *
         * @return the resulting clusters
         * @throws IllegalArgumentException if the data holds fewer than k distinct
         *                                  samples
         */
        public Set<Cluster> run() {
            Set<Cluster> clusterSet = chooseCenterCluster();
            iterRunTimes = 0; // count the passes of this run only
            boolean ifNeedIter = true;
            while (ifNeedIter && iterRunTimes < iterMaxTimes) {
                cluster(clusterSet);
                ifNeedIter = calculateCenter(clusterSet);
                iterRunTimes++;
            }
            return clusterSet;
        }

        /**
         * @return the number of assign/update passes executed by the most recent
         *         {@link #run()}
         */
        public int getIterTimes() {
            return iterRunTimes;
        }
    }
    View Code

    4、Cluster.java

    package com.vue.demo.kmeans;
    
    import java.util.ArrayList;
    import java.util.List;
    
    public class Cluster {
        private int id;// 标识
        private Point center;// 中心
        private List<Point> members = new ArrayList<Point>();// 成员
     
        public Cluster(int id, Point center) {
            this.id = id;
            this.center = center;
        }
     
        public Cluster(int id, Point center, List<Point> members) {
            this.id = id;
            this.center = center;
            this.members = members;
        }
     
        public void addPoint(Point newPoint) {
            if (!members.contains(newPoint)){
                members.add(newPoint);
            }else{
                System.out.println("样本数据点 {"+newPoint.toString()+"} 已经存在!");
            }
        }
     
        public int getId() {
            return id;
        }
     
        public Point getCenter() {
            return center;
        }
     
        public void setCenter(Point center) {
            this.center = center;
        }
     
        public List<Point> getMembers() {
            return members;
        }
     
        @Override
        public String toString() {
            String toString = "Cluster 
    " + "Cluster_id=" + this.id + ", center:{" + this.center.toString()+"}";
            for (Point point : members) {
                toString+="
    "+point.toString();
            }
            return toString+"
    ";
        }
    }
    View Code

    5、DistanceCompute.java

    package com.vue.demo.kmeans;
    
    /**
     * Distance helper used by the clustering code.
     */
    public class DistanceCompute {
        /**
         * Computes the Euclidean distance between two points: the square root of
         * the sum of squared per-dimension differences.
         *
         * @param p1 first point
         * @param p2 second point
         * @return the Euclidean distance between the two coordinate arrays
         * @throws IllegalArgumentException if the coordinate arrays differ in length
         */
        public double getEuclideanDis(Point p1, Point p2) {
            final float[] first = p1.getlocalArray();
            final float[] second = p2.getlocalArray();
            final int dims = first.length;

            // Distances are only defined between points of the same dimension.
            if (dims != second.length) {
                throw new IllegalArgumentException("length of array must be equal!");
            }

            double squaredSum = 0;
            int axis = 0;
            while (axis < dims) {
                squaredSum += Math.pow(first[axis] - second[axis], 2);
                axis++;
            }

            return Math.sqrt(squaredSum);
        }
    }
    View Code

    6、KMeans.java

    package com.vue.demo.kmeans;
    
    /**
     * Created by li on 2019/1/6.
     */
    public class KMeans {
        // Names of the three variables and of the group-by field, as plain strings.
        private String variable1;
        private String variable2;
        private String variable3;
        private String groupbyfield;

        // ---- accessors -------------------------------------------------------

        /** @return the value of variable1, or {@code null} if never set */
        public String getVariable1() {
            return this.variable1;
        }

        /** @return the value of variable2, or {@code null} if never set */
        public String getVariable2() {
            return this.variable2;
        }

        /** @return the value of variable3, or {@code null} if never set */
        public String getVariable3() {
            return this.variable3;
        }

        /** @return the value of groupbyfield, or {@code null} if never set */
        public String getGroupbyfield() {
            return this.groupbyfield;
        }

        // ---- mutators --------------------------------------------------------

        /** @param variable1 new value for variable1 */
        public void setVariable1(String variable1) {
            this.variable1 = variable1;
        }

        /** @param variable2 new value for variable2 */
        public void setVariable2(String variable2) {
            this.variable2 = variable2;
        }

        /** @param variable3 new value for variable3 */
        public void setVariable3(String variable3) {
            this.variable3 = variable3;
        }

        /** @param groupbyfield new value for groupbyfield */
        public void setGroupbyfield(String groupbyfield) {
            this.groupbyfield = groupbyfield;
        }
    }
    View Code
  • 相关阅读:
    问题 A: 走出迷宫(BFS)
    问题 A: 工作团队(并查集删点操作)
    刷题-力扣-989
    刷题-力扣-12
    刷题-力扣-628
    刷题-力扣-11
    刷题-力扣-1018
    刷题-力扣-9
    刷题-力扣-7
    刷题-力扣-6
  • 原文地址:https://www.cnblogs.com/ywjfx/p/12594342.html
Copyright © 2011-2022 走看看