zoukankan      html  css  js  c++  java
  • 决策树ID3算法的java实现

    决策树的分类过程和人的决策过程比较相似,就是先挑“权重”最大的那个考虑,然后再往下细分。比如你去看医生,症状是流鼻涕,咳嗽等,那么医生就会根据你的流鼻涕这个权重最大的症状先认为你是感冒,接着再根据你咳嗽等症状细分你是否为病毒性感冒等等。决策树的过程其实也是基于极大似然估计。那么我们用一个什么标准来衡量某个特征是权重最大的呢,这里有信息增益和基尼系数两个。ID3算法采用的是信息增益这个量。

    根据《统计学习方法》中的描述,G(D,A)表示数据集D在特征A的划分下的信息增益。具体公式:

    G(D,A)=H(D)-H(D|A)。其中H(D)表示数据集D的熵,熵可以用来描述其混乱度,计算公式为

    H(D)=可见对于数据集D而言,|Dk|表示类别为k的数量,其类别越多,越混乱。

    而H(D|A)表示数据集D在A的划分下的的不确定性。他们的差也即是信息增益,表示由于特征A使得数据集D的分类的不确定减少的差,所以这个值越大说明A的分类对D越有效,也就是权重越大。

    H(D|A)=|Dik|表示在特征A中value为i的划分下数据集类别为k的数量。

    有了这两个公式,接下来就可以写代码了。这里为了清晰的表示这个结果,采用了xml来输出。由于刚开始学java所以只能即学即用(java和python简直不能比,python写ID3一百行代码妥妥的搞定,java用了将近300行。。。)

    算法步骤:

    输入:数据集D,特征集A(这里也可以输入一个阀值,如果信息增益小于该阀值就直接作为叶节点,这样可以避免过拟合)

    输出:xml文件

    1 如果D中的类别是同一类,则作为叶节点,标记为该类Ck

    2 如果特征集A中没有特征了,那么作为叶节点,并且用数据集D中类别最多的类作为类标记

    3 对D的各个特征求最大信息增益,选择信息增益最大的特征Ag

    4 对特征Ag中各个值ai继续对数据集进行分割为Di

    5 以Di为数据集,A-Ag为特征集为输入进行1-4步骤

    具体代码:

      1 import java.io.BufferedReader;
      2 import java.io.FileInputStream;
      3 import java.io.FileWriter;
      4 import java.io.IOException;
      5 import java.io.InputStreamReader;
      6 import java.util.ArrayList;
      7 import java.util.HashMap;
      8 import java.util.HashSet;
      9 import java.util.Map;
     10 import java.util.Set;
     11 
     12 import org.dom4j.Document;
     13 import org.dom4j.DocumentHelper;
     14 import org.dom4j.Element;
     15 import org.dom4j.io.XMLWriter;
     16 
     17 
     18 
     19 
     20 
     21 
     22 class Utils{
     23     //用于从文件中获取数据集
     24     public static ArrayList<ArrayList<String>> loadDataSet(String file) throws IOException{
     25         ArrayList<ArrayList<String>> dataSet=new ArrayList<ArrayList<String>>();
     26         FileInputStream fis=new FileInputStream(file);
     27         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
     28         BufferedReader br=new BufferedReader(isr);
     29         String line="";
     30         line=br.readLine();
     31         while((line=br.readLine())!=null){
     32             String[] words=line.split(",");
     33             ArrayList<String> data=new ArrayList<String>();
     34             for(int i=0;i<words.length;i++){
     35                 data.add(words[i]);
     36             }
     37             dataSet.add(data);
     38         }
     39         br.close();
     40         isr.close();
     41         fis.close();
     42         return dataSet;
     43     }
     44     //用于从文件中获取特征
     45     public static ArrayList<String> loadFeature(String file) throws IOException{
     46         FileInputStream fis=new FileInputStream(file);
     47         InputStreamReader isr=new InputStreamReader(fis,"UTF-8");
     48         BufferedReader br=new BufferedReader(isr);
     49         
     50         String[] line=br.readLine().split(",");
     51         ArrayList<String> feature=new ArrayList<String>();
     52         for(int i=0;i<line.length-1;i++){
     53             feature.add(line[i]);
     54         }
     55         br.close();
     56         isr.close();
     57         fis.close();
     58         return feature;
     59     }
     60     //用于获得数据集中的类别列表
     61     public static ArrayList<String> getClassList(ArrayList<ArrayList<String>> dataSet){
     62         ArrayList<String> classList=new ArrayList<String>();
     63         int length=dataSet.get(0).size();
     64         for(ArrayList<String> data:dataSet){
     65             String label=data.get(length-1);
     66             classList.add(label);
     67         }
     68         return classList;
     69     }
     70     //返回数据集中的特征数
     71     public static int featureNum(ArrayList<ArrayList<String>> dataList){
     72         int len=dataList.get(0).size()-1;
     73         return len;
     74     }
     75     
     76     
     77 //    public static void   writeToXML(String fileName) throws IOException{
     78 //        Document document = DocumentHelper.createDocument();
     79 //        Element root = document.addElement("DecisionTree");
     80 //        Element outlook=root.addElement("outlook");
     81 //        outlook.addAttribute("value","sunny");
     82 //        Element humidity1=outlook.addElement("humidity");
     83 //        humidity1.addAttribute("value","high");
     84 //        humidity1.addText("no");
     85 //        Element humidity2=outlook.addElement("humidity");
     86 //        humidity2.addAttribute("value","normal");
     87 //        humidity2.addText("yes");
     88 //        
     89 //        XMLWriter writer=new XMLWriter(new FileWriter(fileName));
     90 //        writer.write(document);
     91 //        writer.close();
     92 //    }
     93     //用于获得数据集中第index列的map映射,方便后续的遍历value和计算熵
     94     public static Map<String,Integer> getSubMap(ArrayList<ArrayList<String>> dataSet,int index){
     95         int total=dataSet.size();
     96         Map<String,Integer> subMap=new HashMap();
     97         for(ArrayList<String> data:dataSet){
     98             String lable=data.get(index);
     99             if(subMap.get(lable)==null){
    100                 subMap.put(lable,1);
    101             }else{
    102                 subMap.put(lable,subMap.get(lable)+1);
    103             }
    104         }
    105         return subMap;
    106     }
    107     //打印map,用于debug的时候
    108     public static void showMap(Map<String,Integer> map){
    109         for(Map.Entry<String,Integer> entry:map.entrySet()){
    110             System.out.println(entry.getKey()+":"+entry.getValue());
    111         }
    112     }
    113     //求熵
    114     public static double getEntropy(ArrayList<ArrayList<String>> dataSet,int index){
    115         int total=dataSet.size();
    116         Map<String,Integer> subMap=getSubMap(dataSet,index);        
    117         double entropy=0;
    118         for(Map.Entry<String,Integer> entry:subMap.entrySet()){
    119             double temp=entry.getValue()*1.0/total;
    120             entropy+=temp*(Math.log(temp)/Math.log(2));
    121         }
    122         return -entropy;
    123     }
    124     //求信息增益最大的分割点
    125     public static String bestFeatureSplit(ArrayList<ArrayList<String>> dataSet,ArrayList<String> featureList){
    126         int length=dataSet.get(0).size();
    127         double totalEntropy=getEntropy(dataSet,length-1);
    128         
    129         
    130         
    131         int featureNum=dataSet.get(0).size()-1;
    132         int index=-1;
    133         double maxInfoGain=-1;
    134         for(int i=0;i<featureNum;i++){
    135             double entropy=getEntropy(dataSet,i);
    136             Map<String,Integer> map=getSubMap(dataSet,i);//获得该特征下的map
    137             ArrayList<String> lableList=new ArrayList<String>();
    138             double entropySum=0;
    139             
    140             for(Map.Entry<String,Integer> entry:map.entrySet()){//这里的Di就是map中的特征的value值
    141                 Map<String,Integer> subMap=new HashMap();
    142                 
    143                 
    144                 for(ArrayList<String> data:dataSet){
    145                     if(data.get(i).compareTo(entry.getKey())==0){
    146                         if(subMap.get(data.get(length-1))==null){
    147                             
    148                             subMap.put(data.get(length-1),1);
    149                         }else{
    150                             subMap.put(data.get(length-1),subMap.get(data.get(length-1))+1);
    151                         }
    152                     }
    153                 }
    154                 double x=0;
    155                 for(Map.Entry<String,Integer> subEntry:subMap.entrySet()){
    156                     double temp=subEntry.getValue()*1.0/entry.getValue();
    157                     x+=temp*(Math.log(temp)/Math.log(2));
    158                 }
    159                 
    160                 entropySum+=-x*(entry.getValue())/dataSet.size();
    161             }
    162             entropySum=totalEntropy-entropySum;
    163             if(entropySum>maxInfoGain){
    164                 index=i;
    165                 maxInfoGain=entropySum;
    166             }
    167         }
    168         return featureList.get(index);
    169     }
    170     //分割数据集,index为特征的下标
    171     public static ArrayList<ArrayList<String>> splitDataSet(ArrayList<ArrayList<String>> dataSet,int index,String value){
    172         ArrayList<ArrayList<String>> subDataSet=new ArrayList<ArrayList<String>>();
    173         for(ArrayList<String> data:dataSet){        
    174             if(data.get(index).compareTo(value)==0){
    175                 ArrayList<String> temp=new ArrayList<String>();
    176                 for(int i=0;i<data.size();i++){
    177                     if(i!=index){
    178                         temp.add(data.get(i));
    179                     }
    180                 }
    181                 subDataSet.add(temp);
    182             }
    183         }
    184         return subDataSet;
    185     }
    186     //list-》map
    187     public static Map<String,Integer> arrayToMap(ArrayList<String> list){
    188         Map<String,Integer> map=new HashMap();
    189         for(String word:list){
    190             if(map.get(word)==null){
    191                 map.put(word,1);
    192             }else{
    193                 map.put(word,map.get(word)+1);
    194             }
    195         }
    196         return map;
    197     }
    198     //求label中某个数量最多的类别
    199     public static String major(ArrayList<String> labelList){
    200         Map<String,Integer> map=arrayToMap(labelList);
    201         int max=0;
    202         String label="";
    203         for(Map.Entry<String,Integer> entry:map.entrySet()){
    204             if(entry.getValue()>max){
    205                 label=entry.getKey();
    206             }
    207         }
    208         return label;
    209     }
    210     
    211     public static Set<String> getValueFromDataSet(ArrayList<ArrayList<String>> dataSet,int index){
    212         ArrayList<String> values=new ArrayList<String>();
    213         for(ArrayList<String> data:dataSet){
    214             try{
    215             values.add(data.get(index));
    216             }catch(Exception e){
    217                 
    218                 System.out.println("index is "+index);
    219             }
    220         }
    221         Set<String> set=new HashSet();
    222         for(String value:values){
    223             set.add(value);
    224         }
    225         return set;
    226     }
    227     
    228     public static ArrayList<String> copyArrayList(ArrayList<String> src){
    229         ArrayList<String> dest=new ArrayList<String>();
    230         for(String s:src){
    231             dest.add(s);
    232         }
    233         return dest;
    234     }
    235     
    236     
    237     public static void showArrayList(ArrayList<ArrayList<String>> dataSet){
    238         for(ArrayList<String> data:dataSet){
    239             System.out.println(data);
    240         }
    241     }
    242     
    243 }
    244 
    245 
    246 public class DecisionTree {
    247     
    248     
    249     public static int  createTree(ArrayList<ArrayList<String>> dataSet,ArrayList<String> featureList,Element e){
    250         ArrayList<String> labelList=Utils.getClassList(dataSet);//获取数据集中label的列表
    251         if(Utils.arrayToMap(labelList).size()==1){//表示label中只有一种类别,所以此时不需要再分类了
    252             e.addText(labelList.get(0));
    253             return 1;
    254         }
    255         if(dataSet.get(0).size()==1){//表示此时已经没有特征了,所以也不需要再继续了,此时以label中最多的类别为该节点的类别
    256             e.addText(Utils.major(labelList));
    257             return 1;
    258         }
    259         
    260         ArrayList<String> subFeatureList=Utils.copyArrayList(featureList);
    261         
    262         
    263         
    264         String feature=Utils.bestFeatureSplit(dataSet,featureList);
    265         subFeatureList.remove(feature);
    266         int index=featureList.indexOf(feature);
    267 
    268         Set<String> valueSet=Utils.getValueFromDataSet(dataSet,index);
    269 //        Element next=e.addElement(feature);//原来的代码位置
    270         for(String value:valueSet){
    271              Element next=e.addElement(feature);//后来放到这里之后,xml的输出就正确了,原因在于每递归一次就需要创建一个element,所以应该在for内创建。
    272              next.addAttribute("value",value);
    273              ArrayList<ArrayList<String>> subDataSet=Utils.splitDataSet(dataSet,index,value);
    274              createTree(subDataSet,subFeatureList,next);
    275         }
    276         return 1;
    277     }
    278     
    279     public static void main(String[] args) throws IOException {
    280         // TODO Auto-generated method stub
    281         String file="C:/Users/Administrator/Desktop/upload/DT.txt";
    282         String xml="C:/Users/Administrator/Desktop/upload/DT1.xml";
    283         ArrayList<ArrayList<String>> dataSet=Utils.loadDataSet(file);
    284         ArrayList<String> featureList=Utils.loadFeature(file);
    285         Document document = DocumentHelper.createDocument();
    286         Element root = document.addElement("DecisionTree");
    287         createTree(dataSet,featureList,root);
    288         XMLWriter writer=new XMLWriter(new FileWriter(xml));
    289         writer.write(document);
    290         writer.close();
    291         System.out.println("finished");
    292     }
    293 
    294 }

    这次除了算法上的理解更加深刻了外,在java上也学到了些关于xml解析,读写等方法。

    另外对递归的使用也更加形象些,对于递归一个容易错的点就是函数上的参数,一定要认真对待,要清楚该参数该在什么时候初始化,什么时候被用到。我一开始在第269行上就出现错误了,一开始没有考虑清楚这个next该在什么时候分配,后来发现每次创建节点的时候我们在xml就要创建一个相应的节点用来描述他,所以应该是在for循环里面创建,如果在for外面创建就表示,该特征下的所有值都只有一个element。

    当然对于set,map的遍历啥的也更加清晰了。

     

  • 相关阅读:
    线段拟合(带拉格朗日乘子,HGL)
    工作到位的标准
    Git的简单使用
    位移
    java日期格式化(util包下转成sql包下)
    java中继承的概念
    工作流驳回到指定连线节点上
    年终个人总结
    实现多条件查询 匹配数据库字段中多个数据
    activiti和SSH项目做整合
  • 原文地址:https://www.cnblogs.com/sunrye/p/4562517.html
Copyright © 2011-2022 走看看