zoukankan      html  css  js  c++  java
  • java编写ID3决策树

    说明:每个样本都会装入Data样本对象,决策树生成算法接收的是一个List<Data>样本列表,所以构建测试数据时也要符合该格式。算法最终返回决策树的根节点,通过其提供的showTree()方法可查看整棵树的结构。下面奉上源码。

    Data.java

    package ai.tree.data;
    
    import java.util.HashMap;
    
    /**
     * A single training sample: a map from feature name to feature value, plus the sample's
     * class label ("result").
     *
     * @author ChenLuyang
     * @date 2019/2/21
     */
    public class Data implements Cloneable {
        /**
         * Feature values of this sample; key is the feature name, value is the feature value.
         * The map handed to the constructor is stored without copying and is returned as-is by
         * {@link #getFeature()}, so callers can mutate it.
         */
        private HashMap<String, String> feature;

        /** Class label (classification result) of this sample. */
        private String result;

        /**
         * Creates a sample.
         *
         * @param feature feature name to feature value map (stored by reference, not copied)
         * @param result  class label of the sample
         */
        public Data(HashMap<String, String> feature, String result) {
            this.feature = feature;
            this.result = result;
        }

        /** @return the live (mutable) feature map of this sample */
        public HashMap<String, String> getFeature() {
            return feature;
        }

        /** @return the class label of this sample */
        public String getResult() {
            return result;
        }

        private void setFeature(HashMap<String, String> feature) {
            this.feature = feature;
        }

        /**
         * Returns a copy of this sample whose feature map is a shallow copy of this one, so
         * removing keys from the copy's map leaves this sample untouched. The map is copied with
         * {@link HashMap#clone()} so that a {@code HashMap} subclass passed to the constructor
         * keeps its runtime type. A {@code null} feature map stays {@code null} in the copy.
         *
         * @return the copy; never {@code null}
         */
        @Override
        public Data clone() {
            try {
                Data copy = (Data) super.clone();
                if (feature != null) {
                    // HashMap.clone() returns Object; the copy has the same type arguments.
                    @SuppressWarnings("unchecked")
                    HashMap<String, String> featureCopy = (HashMap<String, String>) feature.clone();
                    copy.setFeature(featureCopy);
                }
                return copy;
            } catch (CloneNotSupportedException e) {
                // Unreachable because this class implements Cloneable; fail loudly instead of
                // returning null to the caller.
                throw new AssertionError(e);
            }
        }
    }
    

      

    DecisionTree.java

    package ai.tree.algorithm;
    
    import ai.tree.data.Data;

    import java.math.BigDecimal;
    import java.math.RoundingMode;
    import java.util.*;
    
    /**
     * ID3 decision-tree learner: recursively splits a sample set on the feature with the highest
     * information gain until each branch is pure or no features remain.
     *
     * @author ChenLuyang
     * @date 2019/2/21
     */
    public class DecisionTree {
        /**
         * Recursively builds a decision tree from the given samples.
         *
         * <p>A leaf is produced when all samples share one class label, or when no features are
         * left (then the most frequent label is used). Otherwise the best feature is chosen with
         * {@link #chooseBestFeatureToSplit(List)} and one child is built per value that feature
         * takes in {@code dataList}.
         *
         * @param dataList samples to learn from; only the first sample's feature set is inspected,
         *                 so all samples are expected to carry the same feature names
         * @return the root node of the tree built from {@code dataList}
         * @throws IllegalArgumentException if {@code dataList} is empty
         * @author ChenLuyang
         * @date 2019/2/21 16:05
         */
        public TreeNode createTree(List<Data> dataList) {
            if (dataList.isEmpty()) {
                throw new IllegalArgumentException("dataList must contain at least one sample");
            }
            TreeNode<String, String, String> nowTreeNode = new TreeNode<String, String, String>();

            // Distinct class labels present in the current sample set.
            Set<String> resultSet = new HashSet<String>();
            for (Data data : dataList) {
                resultSet.add(data.getResult());
            }

            // Pure set: nothing left to separate, so this node is a leaf.
            if (resultSet.size() == 1) {
                nowTreeNode.setResultNode(resultSet.iterator().next());
                return nowTreeNode;
            }

            // Features exhausted but labels still mixed (the samples agree on every split
            // feature yet carry different labels): fall back to the most frequent label.
            if (dataList.get(0).getFeature().isEmpty()) {
                nowTreeNode.setResultNode(majorityResult(dataList));
                return nowTreeNode;
            }

            String bestLabel = chooseBestFeatureToSplit(dataList);

            // Every value the chosen feature takes in this sample set becomes one branch.
            Set<String> bestLabelInfoSet = new HashSet<String>();
            for (Data data : dataList) {
                bestLabelInfoSet.add(data.getFeature().get(bestLabel));
            }

            Map<String, TreeNode> featureDecisionMap = new HashMap<String, TreeNode>();
            for (String labelInfo : bestLabelInfoSet) {
                // Each branch is non-empty because labelInfo was taken from dataList itself.
                List<Data> branchDataList = splitDataList(dataList, bestLabel, labelInfo);
                featureDecisionMap.put(labelInfo, createTree(branchDataList));
            }

            nowTreeNode.setDecisionNode(bestLabel, featureDecisionMap);
            return nowTreeNode;
        }

        /**
         * Counts how many samples carry each class label.
         *
         * @param dataList samples to count
         * @return class label to number of samples with that label
         */
        private Map<String, Integer> countResults(List<Data> dataList) {
            Map<String, Integer> counts = new HashMap<String, Integer>();
            for (Data data : dataList) {
                Integer count = counts.get(data.getResult());
                counts.put(data.getResult(), count == null ? 1 : count + 1);
            }
            return counts;
        }

        /**
         * Returns the most frequent class label. On a tie, the label met first in the count
         * map's iteration order wins.
         *
         * @param dataList samples to vote over
         * @return the majority label, or {@code ""} if {@code dataList} is empty
         */
        private String majorityResult(List<Data> dataList) {
            String tmpResult = "";
            int tmpNum = 0;
            for (Map.Entry<String, Integer> entry : countResults(dataList).entrySet()) {
                if (entry.getValue() > tmpNum) {
                    tmpNum = entry.getValue();
                    tmpResult = entry.getKey();
                }
            }
            return tmpResult;
        }

        /**
         * Picks the feature whose split yields the largest information gain (ID3 criterion).
         *
         * <p>Branch weights are rounded to 5 decimal places (HALF_DOWN), as in
         * {@link #calcShannonEnt(List)}. On ties, the feature visited last in the first sample's
         * feature-map iteration order wins. Because of the rounding, a computed gain may come
         * out slightly negative. Some feature is still returned in that case, so callers never
         * receive a name that no sample has.
         *
         * @param dataList non-empty sample set; the first sample's feature names are the candidates
         * @return the name of the best feature, or {@code ""} if the samples have no features
         * @author ChenLuyang
         * @date 2019/2/21 14:12
         */
        public String chooseBestFeatureToSplit(List<Data> dataList) {
            Set<String> futureSet = dataList.get(0).getFeature().keySet();

            // Entropy before splitting.
            BigDecimal baseEntropy = calcShannonEnt(dataList);
            BigDecimal totalCount = new BigDecimal(dataList.size() + "");

            // null until the first feature is evaluated, so a feature is always selected.
            BigDecimal bestInfoGain = null;
            String bestFeature = "";

            for (String future : futureSet) {
                // Distinct values of this feature.
                Set<String> futureInfoSet = new HashSet<String>();
                for (Data data : dataList) {
                    futureInfoSet.add(data.getFeature().get(future));
                }

                // Conditional entropy: sum over values of (subset share * subset entropy).
                BigDecimal futureEntropy = new BigDecimal("0");
                for (String futureInfo : futureInfoSet) {
                    List<Data> splitResultDataList = splitDataList(dataList, future, futureInfo);
                    BigDecimal tmpProb = new BigDecimal(splitResultDataList.size() + "")
                            .divide(totalCount, 5, RoundingMode.HALF_DOWN);
                    futureEntropy = futureEntropy.add(tmpProb.multiply(calcShannonEnt(splitResultDataList)));
                }

                BigDecimal infoGain = baseEntropy.subtract(futureEntropy);
                // ">=" keeps the original tie-breaking: a later feature with equal gain wins.
                if (bestInfoGain == null || infoGain.compareTo(bestInfoGain) >= 0) {
                    bestInfoGain = infoGain;
                    bestFeature = future;
                }
            }

            return bestFeature;
        }

        /**
         * Computes the Shannon entropy (base 2) of the class labels in the sample set.
         *
         * <p>Each label probability is rounded to 5 decimal places (HALF_DOWN) before its
         * {@code -p * log2(p)} term is added. A probability that rounds to zero contributes 0
         * (the usual {@code 0 * log 0 = 0} convention). {@code Math.log(0)} is negative
         * infinity, which {@link BigDecimal} cannot represent.
         *
         * @param dataList sample set; an empty list yields 0
         * @return the entropy
         * @author ChenLuyang
         * @date 2019/2/22 9:41
         */
        public BigDecimal calcShannonEnt(List<Data> dataList) {
            BigDecimal sumEntries = new BigDecimal(dataList.size() + "");
            BigDecimal shannonEnt = new BigDecimal("0");

            for (Integer count : countResults(dataList).values()) {
                BigDecimal prob = new BigDecimal(count.toString())
                        .divide(sumEntries, 5, RoundingMode.HALF_DOWN);
                if (prob.signum() == 0) {
                    continue;
                }
                double log2 = Math.log(prob.doubleValue()) / Math.log(2);
                shannonEnt = shannonEnt.subtract(prob.multiply(BigDecimal.valueOf(log2)));
            }

            return shannonEnt;
        }

        /**
         * Selects the samples whose value for {@code future} equals {@code info}. Each selected
         * sample is returned as a clone with that feature removed, so the input samples are not
         * modified.
         *
         * @param dataList samples to filter
         * @param future   feature name to filter on
         * @param info     required feature value; {@code null} matches samples lacking the feature
         * @return the matching samples (clones without {@code future})
         * @author ChenLuyang
         * @date 2019/2/21 18:26
         */
        public List<Data> splitDataList(List<Data> dataList, String future, String info) {
            List<Data> resultDataList = new ArrayList<Data>();
            for (Data data : dataList) {
                String value = data.getFeature().get(future);
                boolean matches = (info == null) ? value == null : info.equals(value);
                if (matches) {
                    Data newData = data.clone();
                    newData.getFeature().remove(future);
                    resultDataList.add(newData);
                }
            }

            return resultDataList;
        }

        /**
         * A node of the decision tree. A node is either a leaf carrying a final class label, or
         * a decision node that tests one feature and maps each feature value to a child node.
         *
         * <p>NOTE(review): this inner class never uses the enclosing {@code DecisionTree}
         * instance and could be {@code static}. It is left as-is to keep the public type's
         * construction syntax unchanged.
         *
         * @param <L> type of the feature name (label) tested by a decision node
         * @param <F> type of the feature values used as branch keys
         * @param <R> type of the final classification result
         */
        public class TreeNode<L, F, R> {
            /** Feature tested by this node (set only for decision nodes). */
            private L label;

            /** Branches of a decision node: feature value to the child node for that value. */
            private Map<F, TreeNode> featureDecisionMap;

            /** Whether this node is a leaf (final classification). */
            private boolean isFinal;

            /** Class label of a leaf node. */
            private R resultClassify;

            /**
             * Turns this node into a leaf.
             *
             * @param resultClassify final class label
             * @author ChenLuyang
             * @date 2019/2/22 18:31
             */
            public void setResultNode(R resultClassify) {
                this.isFinal = true;
                this.resultClassify = resultClassify;
            }

            /**
             * Turns this node into a decision node.
             *
             * @param label              feature tested at this node
             * @param featureDecisionMap feature value to child node
             * @author ChenLuyang
             * @date 2019/2/22 18:31
             */
            public void setDecisionNode(L label, Map<F, TreeNode> featureDecisionMap) {
                this.isFinal = false;
                this.label = label;
                this.featureDecisionMap = featureDecisionMap;
            }

            /**
             * Renders the subtree rooted at this node as nested map text. A leaf renders as
             * {@code {result=<label>}} and a decision node as {@code {<feature>={<value>=<child>, ...}}}.
             *
             * @return the textual tree
             * @author ChenLuyang
             * @date 2019/2/22 16:54
             */
            public String showTree() {
                HashMap<String, String> treeMap = new HashMap<String, String>();
                if (isFinal) {
                    // String.valueOf avoids an NPE if a sample carried a null class label.
                    treeMap.put("result", String.valueOf(resultClassify));
                } else {
                    HashMap<F, String> showFutureMap = new HashMap<F, String>();
                    for (Map.Entry<F, TreeNode> branch : featureDecisionMap.entrySet()) {
                        showFutureMap.put(branch.getKey(), branch.getValue().showTree());
                    }
                    treeMap.put(label.toString(), showFutureMap.toString());
                }

                return treeMap.toString();
            }

            public L getLabel() {
                return label;
            }

            public Map<F, TreeNode> getFeatureDecisionMap() {
                return featureDecisionMap;
            }

            public R getResultClassify() {
                return resultClassify;
            }

            public boolean getFinal() {
                return isFinal;
            }
        }
    }
    

      

    Start.java

    package ai.tree.algorithm;
    
    import ai.tree.data.Data;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    
    /**
     * Demo entry point: builds a small toy sample set, learns an ID3 decision tree from it and
     * prints the tree.
     *
     * @author ChenLuyang
     * @date 2019/2/22
     */
    public class Start {
        /**
         * Builds the toy sample set. Feature names are {@code 是否戴眼镜} (wears glasses),
         * {@code 头发长短} (hair length) and {@code 身材} (build); the class label is gender.
         * The seven samples, in order, are:
         * <pre>
         * #  glasses  hair   build  class
         * 1  yes      short  heavy  male
         * 2  yes      long   thin   female
         * 3  yes      short  heavy  female
         * 4  no       long   heavy  male
         * 5  no       short  thin   male
         * 6  yes      long   thin   female
         * 7  yes      long   heavy  male
         * </pre>
         * Samples 1 and 3 have identical features but different labels, so the learner must fall
         * back to a majority vote on that branch.
         *
         * @return the sample list
         * @author ChenLuyang
         * @date 2019/2/21 15:34
         */
        public static List<Data> createDataList() {
            // Feature names; every sample sets them in this order.
            String[] labels = new String[]{"是否戴眼镜", "头发长短", "身材"};

            List<Data> dataList = new ArrayList<Data>();
            dataList.add(sample(labels, "有眼镜", "短发", "胖", "男"));
            dataList.add(sample(labels, "有眼镜", "长发", "瘦", "女"));
            dataList.add(sample(labels, "有眼镜", "短发", "胖", "女"));
            dataList.add(sample(labels, "没眼镜", "长发", "胖", "男"));
            dataList.add(sample(labels, "没眼镜", "短发", "瘦", "男"));
            dataList.add(sample(labels, "有眼镜", "长发", "瘦", "女"));
            dataList.add(sample(labels, "有眼镜", "长发", "胖", "男"));

            return dataList;
        }

        /**
         * Creates one sample. The three feature values are given in the same order as
         * {@code labels}.
         *
         * @param labels  feature names (glasses, hair length, build)
         * @param glasses value for {@code labels[0]}
         * @param hair    value for {@code labels[1]}
         * @param build   value for {@code labels[2]}
         * @param result  class label
         * @return the new sample
         */
        private static Data sample(String[] labels, String glasses, String hair, String build, String result) {
            // Keys are inserted in label order, matching the original hand-written construction.
            HashMap<String, String> feature = new HashMap<String, String>();
            feature.put(labels[0], glasses);
            feature.put(labels[1], hair);
            feature.put(labels[2], build);
            return new Data(feature, result);
        }

        public static void main(String[] args) {
            DecisionTree decisionTree = new DecisionTree();

            // Learn a tree from the toy samples and print it.
            DecisionTree.TreeNode tree = decisionTree.createTree(createDataList());
            System.out.println(tree.showTree());
        }
    }
    

      

    运行Start后生成的决策树结构(showTree()输出):{是否戴眼镜={没眼镜={result=男}, 有眼镜={身材={胖={头发长短={长发={result=男}, 短发={result=女}}}, 瘦={result=女}}}}}

  • 相关阅读:
    POJ 1811 Prime Test 素性测试 分解素因子
    sysbench的安装与使用
    电脑中已有VS2005和VS2010安装.NET3.5失败的解决方案
    I.MX6 show battery states in commandLine
    RPi 2B Raspbian system install
    I.MX6 bq27441 driver porting
    I.MX6 隐藏电池图标
    I.MX6 Power off register hacking
    I.MX6 Goodix GT9xx touchscreen driver porting
    busybox filesystem httpd php-5.5.31 sqlite3 webserver
  • 原文地址:https://www.cnblogs.com/red-code/p/10420107.html
Copyright © 2011-2022 走看看