zoukankan      html  css  js  c++  java
  • 【决策树】— C4.5算法建立决策树JAVA练习

    以下程序是我练习时写的,不一定正确,也没有做存储优化。有问题请留言交流。转载请附上原文链接。

    当前的属性为:age income student credit_rating

    当前的数据集为(最后一列是TARGET_VALUE):

    ---------------------------------

    youth     high   no   fair      no
    youth     high   no   excellent   no
    middle_aged   high   no   fair     yes
    senior     low    yes  fair     yes
    senior     low    yes  excellent   no
    middle_aged   low    yes  excellent   yes
    youth     medium  no   fair     no
    youth     low     yes  fair     yes
    senior     medium  yes    fair     yes
    youth     medium  yes    excellent   yes
    middle_aged   high   yes  fair        yes
    senior     medium  no     excellent   no
    ---------------------------------

    C4.5 决策树构建类

    package C45Test;
    
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;
    
    /**
     * Builds a C4.5-style decision tree from rows of nominal (string) values.
     *
     * <p>Each row holds one value per attribute, in the same order as the attribute list passed to
     * {@link #createDT}, followed by the target value in its last column.
     */
    public class DecisionTree {
    
        /** Node name given to every leaf created by {@link #createDT}. */
        private static final String LEAF_NODE_NAME = "leafNode";
    
        /**
         * Recursively builds a decision tree, splitting each node on the attribute with the largest
         * strictly positive gain ratio.
         *
         * <p>The rows and attributes being processed, and each chosen split, are printed to standard
         * output. The caller's lists are not modified: rows are reduced on the copies returned by
         * {@link InfoGain#getData4Value}, and the attribute list is copied before an entry is removed.
         *
         * @param data rows of attribute values, each ending with the target value
         * @param attributeList attribute names, aligned with the leading columns of each row
         * @return the root of the (sub)tree; leaves carry their prediction in {@code targetFunValue}
         *     (null only when {@code data} is empty)
         */
        public TreeNode createDT(List<ArrayList<String>> data, List<String> attributeList) {
            printState(data, attributeList);
            TreeNode node = new TreeNode();
            if (data.isEmpty()) {
                // No rows means there is no target value to predict.
                return makeLeaf(node, null);
            }
    
            // Every row agrees on the target: nothing left to split.
            String pure = InfoGain.IsPure(InfoGain.getTarget(data));
            if (pure != null) {
                return makeLeaf(node, pure);
            }
    
            InfoGain gain = new InfoGain(data, attributeList);
            // Majority vote is the prediction whenever the rows cannot be split any further.
            String majority = majorityTarget(gain.getTargetValue());
            if (attributeList.isEmpty()) {
                return makeLeaf(node, majority);
            }
    
            int attrIndex = selectBestAttribute(gain, attributeList.size());
            if (attrIndex < 0) {
                // No attribute has a positive gain ratio, so no split separates the targets.
                return makeLeaf(node, majority);
            }
    
            String attrName = attributeList.get(attrIndex);
            System.out.println("选择出的最大增益率属性为: " + attrName);
            node.setAttributeValue(attrName);
    
            Map<String, Long> attrValueMap = gain.getAttributeValue(attrIndex);
            for (Map.Entry<String, Long> entry : attrValueMap.entrySet()) {
                String branchValue = entry.getKey();
                List<ArrayList<String>> branchData = gain.getData4Value(branchValue, attrIndex);
                System.out.println("当前为" + attrName + "的" + branchValue + "分支。");
                TreeNode child;
                if (branchData.isEmpty()) {
                    // Defensive: branch values are taken from the rows themselves, so this should
                    // not happen; if it does, predict the parent's majority target.
                    child = new TreeNode();
                    child.setNodeName(attrName);
                    child.setTargetFunValue(majority);
                    child.setAttributeValue(branchValue);
                } else {
                    // Drop the consumed attribute column from the copied rows and attribute names.
                    for (ArrayList<String> row : branchData) {
                        row.remove(attrIndex);
                    }
                    List<String> remainingAttrs = new ArrayList<String>(attributeList);
                    remainingAttrs.remove(attrIndex);
                    child = createDT(branchData, remainingAttrs);
                }
                node.getChildTreeNode().add(child);
                node.getPathName().add(branchValue);
            }
            return node;
        }
    
        /** Prints the rows and attribute names currently being processed (debug trace). */
        private void printState(List<ArrayList<String>> data, List<String> attributeList) {
            System.out.println("当前的DATA为");
            for (ArrayList<String> row : data) {
                for (String cell : row) {
                    System.out.print(cell + " ");
                }
                System.out.println();
            }
            System.out.println("---------------------------------");
            System.out.println("当前的ATTR为");
            for (String attr : attributeList) {
                System.out.print(attr + " ");
            }
            System.out.println();
            System.out.println("---------------------------------");
        }
    
        /**
         * Returns the index of the attribute with the largest strictly positive gain ratio, or -1 if
         * no attribute qualifies.
         */
        private int selectBestAttribute(InfoGain gain, int attributeCount) {
            double maxGain = 0.0;
            int best = -1;
            for (int i = 0; i < attributeCount; i++) {
                // A single-valued attribute has zero split information (its gain ratio is undefined)
                // and cannot partition the rows, so it is never a candidate.
                if (gain.getAttributeValue(i).size() < 2) {
                    continue;
                }
                double ratio = gain.getGainRatio(i);
                if (ratio > maxGain) {
                    maxGain = ratio;
                    best = i;
                }
            }
            return best;
        }
    
        /**
         * Returns the most frequent key of {@code counts}; ties go to the lexicographically smallest
         * key so the result does not depend on map iteration order. Returns null for an empty map.
         */
        private String majorityTarget(Map<String, Long> counts) {
            String best = null;
            long bestCount = 0L;
            for (Map.Entry<String, Long> entry : counts.entrySet()) {
                long count = entry.getValue();
                if (count > bestCount || (count == bestCount && entry.getKey().compareTo(best) < 0)) {
                    best = entry.getKey();
                    bestCount = count;
                }
            }
            return best;
        }
    
        /** Marks {@code node} as a leaf predicting {@code target} and returns it. */
        private TreeNode makeLeaf(TreeNode node, String target) {
            node.setNodeName(LEAF_NODE_NAME);
            node.setTargetFunValue(target);
            return node;
        }
    
        /** A node of the decision tree built by {@link #createDT}. */
        class TreeNode {
    
            /** Split attribute name for an internal node; branch value for an empty-branch leaf. */
            private String attributeValue;
            /** Child nodes, one per branch. */
            private List<TreeNode> childTreeNode;
            /** Branch values, index-aligned with {@link #childTreeNode}. */
            private List<String> pathName;
            /** Predicted target value (set on leaves). */
            private String targetFunValue;
            /** Display name; leaves created by {@code createDT} use "leafNode". */
            private String nodeName;
    
            /** Creates an empty node with the given name. */
            public TreeNode(String nodeName) {
                this.nodeName = nodeName;
                this.childTreeNode = new ArrayList<TreeNode>();
                this.pathName = new ArrayList<String>();
            }
    
            /** Creates an empty, unnamed node. */
            public TreeNode() {
                this.childTreeNode = new ArrayList<TreeNode>();
                this.pathName = new ArrayList<String>();
            }
    
            public String getAttributeValue() {
                return attributeValue;
            }
    
            public void setAttributeValue(String attributeValue) {
                this.attributeValue = attributeValue;
            }
    
            /** Returns the live (mutable) list of children. */
            public List<TreeNode> getChildTreeNode() {
                return childTreeNode;
            }
    
            public void setChildTreeNode(List<TreeNode> childTreeNode) {
                this.childTreeNode = childTreeNode;
            }
    
            public String getTargetFunValue() {
                return targetFunValue;
            }
    
            public void setTargetFunValue(String targetFunValue) {
                this.targetFunValue = targetFunValue;
            }
    
            public String getNodeName() {
                return nodeName;
            }
    
            public void setNodeName(String nodeName) {
                this.nodeName = nodeName;
            }
    
            /** Returns the live (mutable) list of branch values. */
            public List<String> getPathName() {
                return pathName;
            }
    
            public void setPathName(List<String> pathName) {
                this.pathName = pathName;
            }
        }
    }

    增益率计算类(取对数时以 e 为底,而不是以 2 为底。增益率是信息增益与分裂信息之比,两者都乘以同一个换底常数,因此底数不影响属性选择。)

    package C45Test;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    
    /**
     * C4.5 split statistics over a table of nominal values: target entropy, conditional entropy,
     * information gain, split information and gain ratio.
     *
     * <p>Each row holds one value per attribute followed by the target value in its last column.
     * Values are matched exactly (case-sensitive, null-safe) by every method, so counts and row
     * subsets always agree. Natural logarithms are used; the base cancels out in the gain ratio.
     */
    public class InfoGain {
    
        /** Defensive copy of the rows; each row ends with its target value. */
        private final List<ArrayList<String>> data;
        /** Defensive copy of the attribute names (not read by the statistics themselves). */
        private final List<String> attribute;
    
        /**
         * Creates a calculator over copies of the given rows and attribute names, so neither the
         * caller nor this object observes later modifications made by the other.
         *
         * @param data rows of attribute values, each ending with the target value
         * @param attribute attribute names, aligned with the leading columns of each row
         */
        public InfoGain(List<ArrayList<String>> data, List<String> attribute) {
            this.data = new ArrayList<ArrayList<String>>(data.size());
            for (ArrayList<String> row : data) {
                this.data.add(new ArrayList<String>(row));
            }
            this.attribute = new ArrayList<String>(attribute);
        }
    
        /** Returns the entropy (in nats) of the target column over all rows. */
        public double getEntropy() {
            return entropyOf(getTargetValue().values(), data.size());
        }
    
        /**
         * Returns the expected target entropy after splitting on an attribute (Info_A in C4.5): the
         * sum over its values v of |D_v| / |D| * entropy(target within D_v).
         *
         * @param attributeIndex column index of the attribute
         */
        public double getInfoAttribute(int attributeIndex) {
            int size = data.size();
            double infoA = 0.0;
            for (Map.Entry<String, Long> entry : getAttributeValue(attributeIndex).entrySet()) {
                long valueCount = entry.getValue();
                Map<String, Long> targetCounts =
                        getAttributeValueTargetValue(entry.getKey(), attributeIndex);
                // Matching is exact everywhere, so the subset size equals valueCount.
                infoA += ((double) valueCount / size) * entropyOf(targetCounts.values(), valueCount);
            }
            return infoA;
        }
    
        /**
         * Counts target values among the rows whose attribute column equals {@code attributeName}.
         *
         * @param attributeName attribute value selecting the subset
         * @param attributeIndex column index of the attribute
         * @return map from target value to its count within that subset
         */
        public Map<String, Long> getAttributeValueTargetValue(String attributeName, int attributeIndex) {
            Map<String, Long> targetValueMap = new HashMap<String, Long>();
            for (ArrayList<String> row : data) {
                if (sameValue(attributeName, row.get(attributeIndex))) {
                    increment(targetValueMap, row.get(row.size() - 1));
                }
            }
            return targetValueMap;
        }
    
        /** Counts how many rows carry each distinct value of the given attribute column. */
        public Map<String, Long> getAttributeValue(int attributeIndex) {
            Map<String, Long> attributeValueMap = new HashMap<String, Long>();
            for (ArrayList<String> row : data) {
                increment(attributeValueMap, row.get(attributeIndex));
            }
            return attributeValueMap;
        }
    
        /**
         * Returns fresh copies of the rows whose attribute column equals {@code attrValue}; callers
         * may modify the returned rows without affecting this object.
         */
        public List<ArrayList<String>> getData4Value(String attrValue, int attrIndex) {
            List<ArrayList<String>> resultData = new ArrayList<ArrayList<String>>();
            for (ArrayList<String> row : data) {
                if (sameValue(attrValue, row.get(attrIndex))) {
                    resultData.add(new ArrayList<String>(row));
                }
            }
            return resultData;
        }
    
        /**
         * Returns the gain ratio of an attribute: information gain divided by split information.
         * An attribute with zero split information (a single distinct value, or no rows) cannot
         * split the data, so 0.0 is returned instead of dividing by zero.
         */
        public double getGainRatio(int attributeIndex) {
            double splitInfo = getSplitInfo(attributeIndex);
            if (splitInfo <= 0.0) {
                return 0.0;
            }
            return getGain(attributeIndex) / splitInfo;
        }
    
        /** Returns the information gain of an attribute: target entropy minus Info_A. */
        public double getGain(int attributeIndex) {
            return getEntropy() - getInfoAttribute(attributeIndex);
        }
    
        /** Returns the split information: the entropy of the attribute's own value distribution. */
        public double getSplitInfo(int attributeIndex) {
            return entropyOf(getAttributeValue(attributeIndex).values(), data.size());
        }
    
        /** Counts how many rows carry each distinct target (last-column) value. */
        public Map<String, Long> getTargetValue() {
            Map<String, Long> targetValueMap = new HashMap<String, Long>();
            for (ArrayList<String> row : data) {
                increment(targetValueMap, row.get(row.size() - 1));
            }
            return targetValueMap;
        }
    
        /** Returns the target (last-column) value of every row, in row order. */
        public static List<String> getTarget(List<ArrayList<String>> data) {
            List<String> list = new ArrayList<String>(data.size());
            for (ArrayList<String> row : data) {
                list.add(row.get(row.size() - 1));
            }
            return list;
        }
    
        /**
         * Returns the single value shared by every element of {@code list}, or null when the list is
         * empty or holds more than one distinct value.
         */
        public static String IsPure(List<String> list) {
            Set<String> distinct = new HashSet<String>(list);
            if (distinct.size() != 1) {
                return null;
            }
            return distinct.iterator().next();
        }
    
        /** Returns {@code -sum(p * ln p)} where each p is a count divided by {@code total}. */
        private static double entropyOf(Iterable<Long> counts, long total) {
            double entropy = 0.0;
            for (long count : counts) {
                double p = (double) count / total;
                entropy -= p * Math.log(p);
            }
            return entropy;
        }
    
        /** Adds one to the count stored under {@code key}. */
        private static void increment(Map<String, Long> counts, String key) {
            Long current = counts.get(key);
            counts.put(key, current == null ? 1L : current + 1L);
        }
    
        /** Null-safe exact string comparison. */
        private static boolean sameValue(String a, String b) {
            return a == null ? b == null : a.equals(b);
        }
    }

    测试类:将上面的数据集和属性名分别读入两个 List(由 configData() 和 configAttribute() 提供),然后调用 createDT 构建决策树。

    package C45Test;
    
    import java.util.ArrayList;
    import java.util.List;
    
    import C45Test.DecisionTree.TreeNode;
    
    /**
     * Demo entry point: builds a C4.5 decision tree for the 12-row age / income / student /
     * credit_rating sample data set.
     */
    public class MainC45 {
    
        /** Attribute names, in column order. */
        private static final String[] ATTRIBUTES = {"age", "income", "student", "credit_rating"};
    
        /** Sample rows: one value per attribute followed by the target value. */
        private static final String[][] ROWS = {
            {"youth", "high", "no", "fair", "no"},
            {"youth", "high", "no", "excellent", "no"},
            {"middle_aged", "high", "no", "fair", "yes"},
            {"senior", "low", "yes", "fair", "yes"},
            {"senior", "low", "yes", "excellent", "no"},
            {"middle_aged", "low", "yes", "excellent", "yes"},
            {"youth", "medium", "no", "fair", "no"},
            {"youth", "low", "yes", "fair", "yes"},
            {"senior", "medium", "yes", "fair", "yes"},
            {"youth", "medium", "yes", "excellent", "yes"},
            {"middle_aged", "high", "yes", "fair", "yes"},
            {"senior", "medium", "no", "excellent", "no"},
        };
    
        private static final List<ArrayList<String>> dataList = new ArrayList<ArrayList<String>>();
        private static final List<String> attributeList = new ArrayList<String>();
    
        public static void main(String[] args) {
            DecisionTree dt = new DecisionTree();
            TreeNode node = dt.createDT(configData(), configAttribute());
            System.out.println();
        }
    
        /** Refills {@link #dataList} with mutable copies of {@link #ROWS} and returns it. */
        private static List<ArrayList<String>> configData() {
            dataList.clear();
            for (String[] row : ROWS) {
                ArrayList<String> record = new ArrayList<String>(row.length);
                for (String cell : row) {
                    record.add(cell);
                }
                dataList.add(record);
            }
            return dataList;
        }
    
        /** Refills {@link #attributeList} with {@link #ATTRIBUTES} and returns it. */
        private static List<String> configAttribute() {
            attributeList.clear();
            for (String name : ATTRIBUTES) {
                attributeList.add(name);
            }
            return attributeList;
        }
    }

    基于 BigDecimal 的十进制运算工具类

    package C45Test;
    import java.math.BigDecimal;
    import java.math.RoundingMode;
    
    /**
     * Decimal arithmetic helpers on {@code double} values.
     *
     * <p>Operands are converted through {@link String#valueOf(double)}, so e.g. {@code 0.1} is
     * treated as the decimal 0.1 rather than its exact binary expansion. Results are converted back
     * to {@code double} and are therefore still limited to double precision.
     */
    public abstract class MathUtils {
    
        /** Number of digits kept after the decimal point by {@link #div}. */
        private static final int DIV_SCALE = 10;
    
        /**
         * Returns {@code value1 + value2} computed in decimal.
         *
         * @throws NumberFormatException if either operand is NaN or infinite
         */
        public static double add(double value1, double value2) {
            return decimal(value1).add(decimal(value2)).doubleValue();
        }
    
        /**
         * Adds two decimal strings exactly, then converts the sum to {@code double}.
         *
         * @throws NumberFormatException if either string is not a valid decimal number
         */
        public static double add(String value1, String value2) {
            return new BigDecimal(value1).add(new BigDecimal(value2)).doubleValue();
        }
    
        /**
         * Returns {@code value1 / value2} rounded half-up to {@link #DIV_SCALE} decimal places.
         *
         * @throws ArithmeticException if {@code value2} is zero
         * @throws NumberFormatException if either operand is NaN or infinite
         */
        public static double div(double value1, double value2) {
            return decimal(value1)
                    .divide(decimal(value2), DIV_SCALE, RoundingMode.HALF_UP)
                    .doubleValue();
        }
    
        /**
         * Returns {@code value1 * value2} computed in decimal.
         *
         * @throws NumberFormatException if either operand is NaN or infinite
         */
        public static double mul(double value1, double value2) {
            return decimal(value1).multiply(decimal(value2)).doubleValue();
        }
    
        /**
         * Returns {@code value1 - value2} computed in decimal.
         *
         * @throws NumberFormatException if either operand is NaN or infinite
         */
        public static double sub(double value1, double value2) {
            return decimal(value1).subtract(decimal(value2)).doubleValue();
        }
    
        /**
         * Returns the larger of the two values, following {@link Math#max(double, double)} for NaN
         * and signed zeros. A comparison needs no decimal conversion; the previous BigDecimal-based
         * version also rejected NaN and infinities.
         */
        public static double returnMax(double value1, double value2) {
            return Math.max(value1, value2);
        }
    
        /** Converts a double to BigDecimal via its shortest decimal string representation. */
        private static BigDecimal decimal(double value) {
            return new BigDecimal(String.valueOf(value));
        }
    }
  • 相关阅读:
    软件上线标准
    rap使用手册
    微服务
    什么是集合?
    什么是maven?maven中的pom文件是做什么的?
    什么是连接池?
    架构
    什么是反射?
    产品
    描述下fastJSON,jackson等等的技术
  • 原文地址:https://www.cnblogs.com/lixusign/p/2548124.html
Copyright © 2011-2022 走看看