zoukankan      html  css  js  c++  java
  • 统计学习方法(五)——决策树

    /*先把标题给写了,这样就能经常提醒自己*/

      决策树是一种容易理解的分类算法,它可以认为是if-then规则的一个集合。主要的优点是模型具有可读性,且分类速度较快,不用进行过多的迭代训练之类。决策树学习通常包括3个步骤:特征选择、决策树的生成和决策树的修剪。比较常用到的算法有ID3、C4.5和CART。

    1. 决策树模型

      决策树是一种树形结构的分类模型,它由结点和有向边组成,结点分为内部结点和叶结点,内部结点表示一个特征或属性,叶结点表示一个类。

    决策树的分类即是从树的根节点开始对实例的某一个特征进行判断,通过内部结点逐步下潜到叶结点的过程。

    2. 特征选择

      特征选择在于选取对训练数据具有分类能力的特征,通常的选择准则是信息增益或信息增益率。为了便于说明,书中给出了一个例子

    希望通过所给的训练数据学习一个贷款申请的决策树,当新客户提出贷款申请时,根据申请人的特征决定是否可贷。

          从认知上个人觉得特征的选择就是找出一些具有代表性,对于分类辨识度高的特征,如此能够快速准确的为实例分类,从数学的角度上来讲,就要涉及到信息论与概率统计中的熵了。在此不赘述太多,直接给出特征选择的算法(信息增益)。

          输入:训练数据集D和特征A;

          输出:特征A对训练数据集D的信息增益 和增益率

     

    (1)   计算数据集D的经验熵

         

    (2)   计算特征A的经验条件熵

         

    (3)   计算信息增益

         

    (4)   信息增益率

         

          对于书中的例子,首先计算经验熵

         

    然后计算各特征的信息增益,分别以 表示年龄、有工作、有房子和信贷情况4个特征,则

          

    分别计算 的信息增益,由于 的信息增益值最大,则选择其为最优特征,当然也可以计算出信息增益率的结果作为选择的依据。

    3. 决策树的生成

    ID3和C4.5算法基本上一样,只是在特征选择的依据上C4.5采用了改进后的信息增益率。因为本文只介绍其中的ID3算法即可。 

    ID3算法步骤

    输入:训练数据集D,特征集A,阈值e

    输出:决策树T

    (1)   若D中所有实例属于同一类Ck,则T为单结点树,并将类Ck作为该结点的类标记,返回T;

    (2)   若A=空,则T为单结点树,将D中实例数最多的类Ck作为结点类标记,返回T;

    (3)   否则,计算A中各特征对D的信息增益,选择信息增益值最大的特征Ag;

    (4)   如果Ag的信息增益小于阈值e,则T为单结点树,将D中最多的类Ck作为结点类标记,返回T;

    (5)   否则,对Ag的每一可能值ai,依Ag=ai将D分割为若干子集Di,将Di中实例数最大多的类作为类标记,构建子结点,由结点及其子结点构成树T,返回T;

    (6)   对于第i个子结点,以Di为训练集,以A-Ag为特征集,递归调用步骤(1)~(5),得到子树Ti,返回Ti。

     

    从描述上感觉决策树的生成还是挺简单明了的,但是具体的实现上树的生成是最最难的,要注意的细节很多,花了俩个晚上才搞好的,遇到了好多坑

    代码块1:信息增益类

    package org.juefan.decisiontree;
    import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import org.juefan.basic.FileIO; import org.juefan.bayes.Data; public class InfoGain { //数据实例存储类 class Data { public ArrayList<Object> x; public Object y; /**读取一行数据转化为标准格式*/ public Data(String content){ String[] strings = content.split(" | |:"); ArrayList<Object> xList = new ArrayList<Object>(); for(int i = 1; i < strings.length; i++){ xList.add(strings[i]); } this.x = new ArrayList<>(); this.x = xList; this.y = strings[0]; } public Data(){ x = new ArrayList<>(); y = 0; } public String toString(){ StringBuilder builder = new StringBuilder(); builder.append("[ "); for(int i = 0; i < x.size() - 1; i++) builder.append(x.get(i).toString()).append(","); builder.append(x.get(x.size() - 1).toString()); builder.append(" ]"); return builder.toString(); } } //返回底数为2的对数值 public static double log2(double d){ return Math.log(d)/Math.log(2); } /** * 计算经验熵 * @param datas 当前数据集,可以为训练数据集中的子集 * @return 返回当前数据集的经验熵 */ public double getEntropy(ArrayList<Data> datas){ int counts = datas.size(); double entropy = 0; Map<Object, Double> map = new HashMap<Object, Double>(); for(Data data: datas){ if(map.containsKey(data.y)){ map.put(data.y, map.get(data.y) + 1); }else { map.put(data.y, 1D); } } for(double v: map.values()) entropy -= (v/counts * log2(v/counts)); return entropy; } /** * 计算条件熵 * @param datas 当前数据集,可以为训练数据集中的子集 * @param feature 待计算的特征位置 * @return 第feature个特征的条件熵 */ public double getCondiEntropy(ArrayList<Data> datas, int feature){ int counts = datas.size(); double condiEntropy = 0; Map<Object, ArrayList<Data>> tmMap = new HashMap<>(); for(Data data: datas){ if(tmMap.containsKey(data.x.get(feature))){ tmMap.get(data.x.get(feature)).add(data); }else { ArrayList<Data> tmDatas = new ArrayList<>(); tmDatas.add(data); tmMap.put(data.x.get(feature), tmDatas); } } for(ArrayList<Data> datas2: tmMap.values()){ condiEntropy += (double)datas2.size()/counts * getEntropy(datas2); } return condiEntropy; } /** * 计算信息增益(ID3算法) * @param datas 当前数据集,可以为训练数据集中的子集 * @param feature 待计算的特征位置 * @return 第feature个特征的信息增益 */ public double getInfoGain(ArrayList<Data> datas, int feature){ return getEntropy(datas) - getCondiEntropy(datas, feature); } /** * 计算信息增益率(C4.5算法) * @param datas 当前数据集,可以为训练数据集中的子集 * @param feature 待计算的特征位置 * @return 第feature个特征的信息增益率 */ public double getInfoGainRatio(ArrayList<Data> datas, int feature){ return getInfoGain(datas, feature)/getEntropy(datas); } }

    代码块2:决策树类

    package org.juefan.decisiontree;
    import java.util.ArrayList; import java.util.List; public class TreeNode { private String feature;  //候选特征 private List<TreeNode> childTreeNode; private String targetFunValue;  //特征对应的值 private String nodeName;  //分类的类别 public TreeNode(String nodeName){ this.nodeName = nodeName; this.childTreeNode = new ArrayList<TreeNode>(); } public TreeNode(){ this.childTreeNode = new ArrayList<TreeNode>(); } public void printTree(){ if(targetFunValue != null) System.out.print("特征值: " + targetFunValue + " "); if(nodeName != null) System.out.print("类型: " + nodeName + " "); System.out.println(); for(TreeNode treeNode: childTreeNode){ System.out.println("当前特征为:" + feature); treeNode.printTree(); } }
    public String getAttributeValue() { return feature; } public void setAttributeValue(String attributeValue) { this.feature = attributeValue; } public List<TreeNode> getChildTreeNode() { return childTreeNode; } public void setChildTreeNode(List<TreeNode> childTreeNode) { this.childTreeNode = childTreeNode; } public String getTargetFunValue() { return targetFunValue; } public void setTargetFunValue(String targetFunValue) { this.targetFunValue = targetFunValue; } public String getNodeName() { return nodeName; } public void setNodeName(String nodeName) { this.nodeName = nodeName; } }

    代码块3:决策树的生成

    package org.juefan.decisiontree;
    
    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Map;
    import java.util.Set;
    import org.juefan.basic.FileIO;
    import org.juefan.bayes.Data;
    
    public class DecisionTree {
        public static final double e = 0.1;
        public InfoGain infoGain = new InfoGain();
        
        public TreeNode buildTree(ArrayList<Data> datas, ArrayList<String> featureName){
            TreeNode treeNode = new TreeNode();
            ArrayList<String> feaName = new ArrayList<>();
            feaName = featureName;
            if(isSingle(datas) || getMaxInfoGain(datas) < e){
                treeNode.setNodeName(getLabel(datas).toString());
                return treeNode;
            }else  {
                int feature = getMaxInfoGainFeature(datas);
                treeNode.setAttributeValue(feaName.get(feature + 1));
                ArrayList<String> tList = new ArrayList<>();
                tList = feaName;
                Map<Object, ArrayList<Data>> tMap = new HashMap<>();
                for(Data data: datas){
                    if(tMap.containsKey(data.x.get(feature))){
                        Data tData = new Data();
                        for(int i = 0; i < data.x.size(); i++)
                            if(i != feature)
                                tData.x.add(data.x.get(i));
                        tData.y = data.y;
                        tMap.get(data.x.get(feature)).add(tData);
                    }else {
                        Data tData = new Data();
                        for(int i = 0; i < data.x.size(); i++)
                            if(i != feature)
                                tData.x.add(data.x.get(i));
                        tData.y = data.y;
                        ArrayList<Data> tDatas = new ArrayList<>();
                        tDatas.add(tData);
                        tMap.put(data.x.get(feature),tDatas);
                    }
                }
                List<TreeNode> treeNodes = new ArrayList<>();
                int child = 0;
                for(Object key: tMap.keySet()){
                    //这一步太坑爹了,java的拷背坑真多啊,害我浪费了半天的时间
                    ArrayList<String> tList2 = new ArrayList<>(tList);
                    tList2.remove(feature + 1);
                    treeNodes.add(buildTree(tMap.get(key), tList2));
                    treeNodes.get(child ++).setTargetFunValue(key.toString());
                }
                treeNode.setChildTreeNode(treeNodes);
                feaName.remove(feature + 1);
            }    
            return treeNode;
        }
        
        /**
         * 获取实例中的最大类
         * @param datas 实例集
         * @return 出现次数最多的类
         */
        public Object getLabel(ArrayList<Data> datas){
            Map<Object, Integer> map = new HashMap<Object, Integer>();
            Object label = null;
            int max = 0;
            for(Data data: datas){
                if(map.containsKey(data.y)){
                    map.put(data.y, map.get(data.y) + 1);
                    if(map.get(data.y) > max){
                        max = map.get(data.y);
                        label = data.y;
                    }
                }else {
                    map.put(data.y, 1);
                }
            }
            return label;
        }
        
        /**
         * 计算信息增益(率)的最大值
         * @param datas
         * @return 最大的信息增益值
         */
        public double getMaxInfoGain(ArrayList<Data> datas){
            double max = 0;
            for(int i = 0; i < datas.get(0).x.size(); i++){
                double temp = infoGain.getInfoGain(datas, i);
                if(temp > max)
                    max = temp;
            }
            return max;
        }
        
        /**信息增益最大的特征*/
        public int getMaxInfoGainFeature(ArrayList<Data> datas){
            double max = 0;
            int feature = 0;
            for(int i = 0; i < datas.get(0).x.size(); i++){
                double temp = infoGain.getInfoGain(datas, i);
                if(temp > max){
                    max = temp;
                    feature = i;
                }
            }
            return feature;
        }
        
        /**判断是否只有一类*/
        public boolean isSingle(ArrayList<Data> datas){
            Set<Object> set = new HashSet<>();
            for(Data data: datas)
                set.add(data.y);
            return set.size() == 1? true:false;
        }
        
        public static void main(String[] args) {
            ArrayList<Data> datas = new ArrayList<>();
            FileIO fileIO = new FileIO();
            DecisionTree decisionTree = new DecisionTree();
            fileIO.setFileName(".//file//decision.tree.txt");
            fileIO.FileRead("utf-8");
            ArrayList<String> featureName = new ArrayList<>();
            //获取文件的标头
            for(String string: fileIO.fileList.get(0).split("	"))
                featureName.add(string);
            for(int i = 1; i < fileIO.fileList.size(); i++){
                datas.add(new Data(fileIO.fileList.get(i)));
            }
            TreeNode treeNode = new TreeNode();
            treeNode = decisionTree.buildTree(datas, featureName);
            treeNode.printTree();
        }
    }

     运行情况:

    输入文件 ".//file//decision.tree.txt" 内容为:

    类型 年龄 有工作 有自己的房子 信贷情况
    否 青年 否 否 一般
    否 青年 否 否 好
    是 青年 是 否 好
    是 青年 是 是 一般
    否 青年 否 否 一般
    否 中年 否 否 一般
    否 中年 否 否 好
    是 中年 是 是 好
    是 中年 否 是 非常好
    是 中年 否 是 非常好
    是 老年 否 是 非常好
    是 老年 否 是 好
    是 老年 是 否 好
    是 老年 是 否 非常好
    否 老年 否 否 一般

    运行结果为:

    当前特征为:有自己的房子
    特征值: 是 类型: 是
    当前特征为:有自己的房子
    特征值: 否
    当前特征为:有工作
    特征值: 是 类型: 是
    当前特征为:有工作
    特征值: 否 类型: 否

    对代码有兴趣的可以上本人的GitHub查看:https://github.com/JueFan/StatisticsLearningMethod/

    里面也有具体的实例数据

  • 相关阅读:
    out/host/linuxx86/obj/EXECUTABLES/aapt_intermediates/aapt 64 32 操作系统
    linux 查看路由器 电脑主机 端口号 占用
    linux proc进程 pid stat statm status id 目录 解析 内存使用
    linux vim 设置大全详解
    ubuntu subclipse svn no libsvnjavahl1 in java.library.path no svnjavahl1 in java.library.path no s
    win7 安装 ubuntu 双系统 详解 easybcd 工具 不能进入 ubuntu 界面
    Atitit.json xml 序列化循环引用解决方案json
    Atitit.编程语言and 自然语言的比较and 编程语言未来的发展
    Atitit.跨语言  文件夹与文件的io操作集合  草案
    Atitit.atijson 类库的新特性设计与实现 v3 q31
  • 原文地址:https://www.cnblogs.com/juefan/p/3843560.html
Copyright © 2011-2022 走看看