zoukankan      html  css  js  c++  java
  • 决策树(2)

    package TreeStructure;
    
    import java.util.ArrayList;
    import java.util.List;
    
    public class testClass {
        /**
         * Demo driver: builds a decision tree from a small hard-coded training table and then
         * classifies one sample by walking the tree.
         *
         * <p>The raw table's columns are first permuted ({@code columnOrder[j]} is the raw column
         * that becomes working column j). The permuted table is echoed to stdout and converted to
         * string rows, because {@code DecisionTree.createDT} works on strings.
         *
         * <p>NOTE(review): the attribute names are not permuted together with the columns; confirm
         * that {weather, thin, cloth, target} describes the permuted column order.
         */
        public static void main(String[] args) {
            double[][] exercise = {{1,1,0,0},{1,3,1,1},{3,2,0,0},{3,2,1,10},{3,2,1,10},{3,2,1,10},{2,2,1,1},{3,2,1,9},{2,3,0,1},{2,1,0,0},{3,2,0,1},{2,1,0,1},{1,1,0,1}};
            String[] attributeNames = {"weather", "thin", "cloth", "target"};
            // Working column j is taken from raw column columnOrder[j].
            int[] columnOrder = {1, 0, 2, 3};

            double[][] exerciseData = new double[exercise.length][];
            for (int i = 0; i < exercise.length; i++) {
                exerciseData[i] = new double[exercise[i].length];
                for (int j = 0; j < exerciseData[i].length; j++) {
                    exerciseData[i][j] = exercise[i][columnOrder[j]];
                }
            }

            // Echo the permuted table.
            for (double[] row : exerciseData) {
                for (double value : row) {
                    System.out.print("  " + value);
                }
                System.out.println();
            }

            // createDT consumes string rows; values are rendered with Double.toString (e.g. "2.0").
            List<ArrayList<String>> data = new ArrayList<ArrayList<String>>();
            for (double[] row : exerciseData) {
                ArrayList<String> stringRow = new ArrayList<String>(row.length);
                for (double value : row) {
                    stringRow.add(String.valueOf(value));
                }
                data.add(stringRow);
            }

            List<String> attribute = new ArrayList<String>();
            for (String name : attributeNames) {
                attribute.add(name);
            }

            DecisionTree dt = new DecisionTree();
            // A null start node makes createDT create the "start" root itself.
            TreeNode root = dt.createDT(data, attribute, null);

            // Sample to classify: values for the first two (permuted) attributes.
            double[] sample = {2, 3};
            List<Double> query = new ArrayList<Double>(sample.length);
            for (double value : sample) {
                query.add(value);
            }

            // traverse removes values from the list as it descends.
            root.traverse(query);

            System.out.println();
        }

    }
    package TreeStructure;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    
    public class DecisionTree {

        /**
         * Recursively grows a decision tree below {@code node}.
         *
         * <p>The split is always made on the first entry of {@code attributeList}. No gain is
         * computed here, even though the log line says "max gain ratio attribute". When only one
         * column remains, every remaining row contributes its own "target" leaf. Duplicates are
         * therefore kept, which lets a later traversal count them.
         *
         * @param data          rows of string values; column k is presumed to belong to
         *                      {@code attributeList.get(k)}. The rows themselves are not modified
         *                      (splitting works on copies made by InfoGain).
         * @param attributeList attribute names for the columns of {@code data}; not modified
         * @param node          node to attach children to, or {@code null} to start a new tree
         *                      under a fresh "start" root
         * @return the node the children were attached to (the new root when {@code node} was null)
         */
        public TreeNode createDT(List<ArrayList<String>> data, List<String> attributeList, TreeNode node) {
            dumpState(data, attributeList);

            TreeNode parent = node;
            if (parent == null) {
                parent = new TreeNode();
                parent.setAttributeValue("start");
                parent.setNodeName("start");
            }

            // Only the last column is left: hang one leaf per row under the parent.
            if (attributeList.size() == 1) {
                for (ArrayList<String> row : data) {
                    TreeNode leaf = new TreeNode();
                    leaf.setAttributeValue(row.get(0));
                    leaf.setNodeName("target");
                    parent.getChildTreeNode().add(leaf);
                }
                return parent;
            }

            String splitAttribute = attributeList.get(0);
            System.out.println("选择出的最大增益率属性为: " + splitAttribute);
            InfoGain gain = new InfoGain(data, attributeList);

            // One branch per distinct value of the split column.
            for (String value : gain.getAttributeValue(0).keySet()) {
                List<ArrayList<String>> subset = gain.getData4Value(value, 0);

                TreeNode branch = new TreeNode();
                branch.setAttributeValue(value);
                branch.setNodeName(splitAttribute);
                parent.getChildTreeNode().add(branch);

                System.out.println("当前为" + splitAttribute + "的" + value + "分支。");

                // The subset rows are copies, so dropping the consumed column is safe.
                for (ArrayList<String> row : subset) {
                    row.remove(0);
                }
                List<String> remaining = new ArrayList<String>(attributeList.subList(1, attributeList.size()));
                createDT(subset, remaining, branch);
            }
            return parent;
        }

        /** Prints the current rows and attribute names as a debug trace. */
        private void dumpState(List<ArrayList<String>> data, List<String> attributeList) {
            System.out.println("当前的DATA为");
            for (ArrayList<String> row : data) {
                StringBuilder line = new StringBuilder();
                for (String cell : row) {
                    line.append(cell).append(' ');
                }
                System.out.println(line);
            }
            System.out.println("---------------------------------");
            System.out.println("当前的ATTR为");
            for (String name : attributeList) {
                System.out.print(name + " ");
            }
            System.out.println();
            System.out.println("---------------------------------");
        }
    }
                
                
                
                
                
                
                
                
                
        
        
    
        
    package TreeStructure;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    
    /**
     * Holds a private copy of a table of string rows and answers per-column queries on it.
     *
     * <p>NOTE(review): despite the name, this class does not compute information gain.
     */
    public class InfoGain {
        /** Private copy of the rows; later changes by the caller do not affect it. */
        private final List<ArrayList<String>> data;
        /** Private copy of the attribute names (presumably one per row column). */
        private final List<String> attribute;

        /**
         * Copies {@code data} row by row and copies {@code attribute}. The strings themselves are
         * shared, since they are immutable.
         */
        public InfoGain(List<ArrayList<String>> data, List<String> attribute) {
            this.data = new ArrayList<ArrayList<String>>(data.size());
            for (List<String> row : data) {
                this.data.add(new ArrayList<String>(row));
            }
            this.attribute = new ArrayList<String>(attribute);
        }

        /**
         * Counts how often each value occurs in the given column.
         *
         * @param attributeIndex column to inspect
         * @return map from column value to occurrence count (a HashMap, so iteration order is unspecified)
         */
        public Map<String, Long> getAttributeValue(int attributeIndex) {
            Map<String, Long> counts = new HashMap<String, Long>();
            for (ArrayList<String> row : data) {
                String key = row.get(attributeIndex);
                Long previous = counts.get(key);
                counts.put(key, previous == null ? 1L : previous + 1L);
            }
            return counts;
        }

        /**
         * Returns copies of the rows whose column {@code attrIndex} equals {@code attrValue},
         * ignoring case. Callers may modify the returned rows freely.
         */
        public List<ArrayList<String>> getData4Value(String attrValue, int attrIndex) {
            List<ArrayList<String>> matches = new ArrayList<ArrayList<String>>();
            for (ArrayList<String> row : data) {
                if (row.get(attrIndex).equalsIgnoreCase(attrValue)) {
                    matches.add(new ArrayList<String>(row));
                }
            }
            return matches;
        }

        /**
         * Extracts the last value of every row (the target label), in row order.
         * Empty rows have no label and are skipped. Previously the first empty row stopped the
         * scan, and every later row was dropped.
         */
        public static List<String> getTarget(List<ArrayList<String>> data) {
            List<String> targets = new ArrayList<String>();
            for (ArrayList<String> row : data) {
                if (row.isEmpty()) {
                    continue;
                }
                targets.add(row.get(row.size() - 1));
            }
            return targets;
        }

        /**
         * Checks whether a label list is 100% pure.
         *
         * @param list labels to check
         * @return the shared label when every entry is equal; {@code null} when the labels differ
         *         or when the list is null or empty
         */
        public static String IsPure(List<String> list) {
            if (list == null || list.isEmpty()) {
                return null;
            }
            String first = list.get(0);
            for (String label : list) {
                boolean same = (first == null) ? label == null : first.equals(label);
                if (!same) {
                    return null;
                }
            }
            return first;
        }

    }
    package TreeStructure;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    
    
     /**
      * Node of the decision tree. When built by {@code DecisionTree.createDT}, {@code nodeName}
      * holds the attribute a node was split on ("start" for the root, "target" for leaves), and
      * {@code attributeValue} holds that attribute's value for the branch.
      */
     class TreeNode {

         private String attributeValue;
         private List<TreeNode> childTreeNode;
         private List<String> pathName;
         private String targetFunValue;
         private String nodeName;

         public TreeNode(String nodeName) {
             this();
             this.nodeName = nodeName;
         }

         public TreeNode() {
             this.childTreeNode = new ArrayList<TreeNode>();
             this.pathName = new ArrayList<String>();
         }

         public String getAttributeValue() {
             return attributeValue;
         }

         public void setAttributeValue(String attributeValue) {
             this.attributeValue = attributeValue;
         }

         public List<TreeNode> getChildTreeNode() {
             return childTreeNode;
         }

         public void setChildTreeNode(List<TreeNode> childTreeNode) {
             this.childTreeNode = childTreeNode;
         }

         public String getTargetFunValue() {
             return targetFunValue;
         }

         public void setTargetFunValue(String targetFunValue) {
             this.targetFunValue = targetFunValue;
         }

         public String getNodeName() {
             return nodeName;
         }

         public void setNodeName(String nodeName) {
             this.nodeName = nodeName;
         }

         public List<String> getPathName() {
             return pathName;
         }

         public void setPathName(List<String> pathName) {
             this.pathName = pathName;
         }

         /** Prints this subtree depth-first: "name:   value", then the child count, then each child. */
         public void traverse() {
             System.out.println(this.getNodeName() + ":   " + this.getAttributeValue());
             int childNumber = this.childTreeNode.size();
             System.out.println(childNumber);
             for (TreeNode child : this.childTreeNode) {
                 child.traverse();
             }
         }

         /**
          * Collects the attribute values of the bottom-level nodes under {@code node}, left to right.
          *
          * <p>The method descends until it reaches a level whose first node has no children. Only
          * the first child is checked, so all siblings are assumed to sit at the same depth.
          *
          * @return the collected values; empty when {@code node} has no children
          */
         public List<String> getTarget(TreeNode node) {
             List<String> targets = new ArrayList<String>();
             List<TreeNode> children = node.getChildTreeNode();
             if (children.isEmpty()) {
                 return targets;
             }
             if (children.get(0).getChildTreeNode().isEmpty()) {
                 // The next level holds the targets.
                 for (TreeNode child : children) {
                     targets.add(child.getAttributeValue());
                 }
             } else {
                 for (TreeNode child : children) {
                     targets.addAll(getTarget(child));
                 }
             }
             return targets;
         }

         /**
          * Classifies a query by walking down the tree.
          *
          * <p>Each element of {@code list} must be a {@link Number}. It is the value for the
          * attribute at the next level, and it is compared with the children's values parsed as
          * doubles. Every matched value is removed from {@code list}, so the list is consumed.
          * The walk stops when the list is empty or when no child matches the next value. At that
          * node, a majority vote is printed over the bottom-level values beneath it.
          *
          * @param list query values, in tree-level order; modified in place
          */
         public void traverse(List<?> list) {
             if (list.isEmpty()) {
                 printMajorityDecision();
                 return;
             }
             double wanted = ((Number) list.get(0)).doubleValue();
             int childlistNumber = this.childTreeNode.size();
             System.out.println(childlistNumber);
             for (TreeNode child : this.childTreeNode) {
                 double childValue = Double.parseDouble(child.getAttributeValue());
                 if (wanted == childValue) {
                     System.out.println(child.getNodeName() + ":   " + child.getAttributeValue());
                     list.remove(0);
                     child.traverse(list);
                     // The head of the list is gone, so stop comparing it with the remaining siblings.
                     return;
                 }
             }
             // This value has no branch (unseen in training): decide by the subtree's majority.
             printMajorityDecision();
         }

         /**
          * Prints how often each distinct bottom-level value under this node occurs, in first-seen
          * order, and then the most frequent one (the earliest wins ties). Prints nothing when
          * there are no values.
          */
         private void printMajorityDecision() {
             List<String> targets = getTarget(this);
             List<String> labels = new ArrayList<String>();
             Map<String, Integer> counts = new HashMap<String, Integer>();
             for (String target : targets) {
                 Integer seen = counts.get(target);
                 if (seen == null) {
                     labels.add(target);
                     counts.put(target, 1);
                 } else {
                     counts.put(target, seen + 1);
                 }
             }
             if (labels.isEmpty()) {
                 return;
             }
             String best = null;
             int bestCount = -1;
             for (String label : labels) {
                 int count = counts.get(label);
                 System.out.println(label + "的数量是:   " + count);
                 if (count > bestCount) {
                     best = label;
                     bestCount = count;
                 }
             }
             System.out.println("选择" + best + "为最终决策");
         }
     }
            
        
     
  • 相关阅读:
    257. Binary Tree Paths
    324. Wiggle Sort II
    315. Count of Smaller Numbers After Self
    350. Intersection of Two Arrays II
    295. Find Median from Data Stream
    289. Game of Life
    287. Find the Duplicate Number
    279. Perfect Squares
    384. Shuffle an Array
    E
  • 原文地址:https://www.cnblogs.com/yunerlalala/p/6119833.html
Copyright © 2011-2022 走看看