zoukankan      html  css  js  c++  java
  • CART剪枝

    与上篇文章中提到的ID3算法和C4.5算法类似,CART算法也是一种决策树分类算法。CART分类回归树算法的本质也是对数据进行分类的,最终数据的表现形式也是以树形的模式展现的,CART与ID3,C4.5所采用的分类标准是不同了。

    下面列出了其中的一些不同之处:

    1、CART最后形成的树是一个二叉树,每个节点会分成2个节点,左孩子节点和右孩子节点,于是这就要求CART算法在所选定的属性中又要划分出最佳的属性划分值,节点如果选定了划分属性名称还要确定里面按照哪个值做一个二元的划分(为属性的值为一类,否则为零另一类)

    而在ID3和C4.5中是按照分类属性的值类型进行划分(属性的取值可以为1个也可以为多个)

    2、CART算法对于属性的值采用的是基于Gini系数值的方式做比较,gini某个属性的某次值的划分的gini指数的值为:

    ,pk就是分别为正负实例的概率,gini系数越小说明分类纯度越高,可以想象成与熵的定义一样。因此在最后计算的时候我们只取其中值最小的做出划分。最后做比较的时候用的是gini的增益做比较,要对分类号的数据做出一个带权重的gini指数的计算

    3

    CART算法在按照Gini指数构建好的树,但是这样构建的树和ID3和C4.5一样,容易导致过拟合现象,在数据集中,在测试集中过拟合的决策树的错误率比经过简化的决策树的错误率要高,过拟合的决策树对训练集拟合的很好,错误率很低(但这并不代表这样的模型是最好的)。 

    现在问题就在于,如何(HOW)在原生的过拟合决策树的基础上,生成简化版的决策树,可以通过剪枝的方法来简化过拟合的决策树。剪枝可以分为两种:预剪枝(Pre-Pruning)和后剪枝(Post-Pruning),下面我们来详细学习下这两种方法: 
    PrePrune:预剪枝,及早的停止树增长,方法可以参考见上面树停止增长的方法。 
    PostPrune:后剪枝,在已生成过拟合决策树上进行剪枝,可以得到简化版的剪枝决策树。 


    常见的后剪枝发包括代价复杂度剪枝,悲观误差剪枝等等。为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝,代价复杂度剪枝的算法公式为:

    α表示的是每个非叶子节点的误差增益率,可以理解为误差代价,最后选出误差代价最小的一个节点进行剪枝(减掉此分支和原来的树的训练集的准确度相差最小)。

    里面变量的意思为:

    是子树中包含的叶子节点个数;

    是节点t的误差代价,如果该节点被剪枝;

    r(t)是节点t的误差率;

    p(t)是节点t上的数据占所有数据的比例。

    是子树Tt的误差代价,如果该节点不被剪枝。它等于子树Tt上所有叶子节点的误差代价之和。下面说说我对于这个公式的理解:其实这个公式的本质是对于剪枝前和剪枝后的样本偏差率做一个差值比较,一个好的分类当然是分类后的样本偏差率相较于没分类(就是剪枝掉的时候)的偏差率小,所以这时的值就会大,如果分类前后基本变化不大,则意味着分类不起什么效果,α值的分子位置就小,所以误差代价就小,可以被剪枝。但是一般分类后的偏差率会小于分类前的,因为偏差数在高层节点的时候肯定比子节点的多,子节点偏差数最多与父亲节点一样。

    主程序如下:

     1 package DataMining_CART;
     2 
     3 public class Client {
     4     public static void main(String[] args){
     5         String filePath = "E:\code\data mining\DataMining_CART\src\DataMining_CART\input.txt";//训练数据集的路径
     6         
     7         CARTTool tool = new CARTTool(filePath);//构造函数
     8         
     9         tool.startBuildingTree();//CART主程序的入口
    10     }
    11 }
    View Code

    每个节点的类如下

     1 package DataMining_CART;
     2 
     3 import java.util.ArrayList;
     4 
     5 /**
     6  * 回归分类树节点
     7  * 
     8  * @author clj
     9  * 
    10  */
    11 public class AttrNode {
    12     // 节点属性名字
    13     private String attrName;
    14     // 节点索引标号
    15     private int nodeIndex;
    16     //包含的叶子节点数
    17     private int leafNum;//记录叶子节点的个数的目的是什么呢
    18     // 节点误差率
    19     private double alpha;
    20     // 父亲分类属性值
    21     private String parentAttrValue;
    22     // 孩子节点
    23     private AttrNode[] childAttrNode;
    24     // 数据记录索引
    25     private ArrayList<String> dataIndex;
    26 
    27     public String getAttrName() {
    28         return attrName;
    29     }
    30 
    31     public void setAttrName(String attrName) {
    32         this.attrName = attrName;
    33     }
    34 
    35     public int getNodeIndex() {
    36         return nodeIndex;
    37     }
    38 
    39     public void setNodeIndex(int nodeIndex) {
    40         this.nodeIndex = nodeIndex;
    41     }
    42 
    43     public double getAlpha() {
    44         return alpha;
    45     }
    46 
    47     public void setAlpha(double alpha) {
    48         this.alpha = alpha;
    49     }
    50 
    51     public String getParentAttrValue() {
    52         return parentAttrValue;
    53     }
    54 
    55     public void setParentAttrValue(String parentAttrValue) {
    56         this.parentAttrValue = parentAttrValue;
    57     }
    58 
    59     public AttrNode[] getChildAttrNode() {
    60         return childAttrNode;
    61     }
    62 
    63     public void setChildAttrNode(AttrNode[] childAttrNode) {
    64         this.childAttrNode = childAttrNode;
    65     }
    66 
    67     public ArrayList<String> getDataIndex() {
    68         return dataIndex;
    69     }
    70 
    71     public void setDataIndex(ArrayList<String> dataIndex) {
    72         this.dataIndex = dataIndex;
    73     }
    74 
    75     public int getLeafNum() {
    76         return leafNum;
    77     }
    78 
    79     public void setLeafNum(int leafNum) {
    80         this.leafNum = leafNum;
    81     }
    82     
    83     
    84     
    85 }
    View Code

    关键程序CARTTool如下

      1 package DataMining_CART;
      2 
      3 import java.io.BufferedReader;
      4 import java.io.File;
      5 import java.io.FileReader;
      6 import java.io.IOException;
      7 import java.util.ArrayList;
      8 import java.util.HashMap;
      9 import java.util.LinkedList;
     10 import java.util.Queue;
     11 
     12 
     13 
     14 /**
     15  * CART分类回归树算法工具类
     16  * 本文是使用的后剪枝的方法
     17  * @author clj
     18  * 
     19  */
     20 public class CARTTool {
     21     // 类标号的值类型
     22     private final String YES = "Yes";
     23     //private final String NO = "No";
     24 
     25     // 所有属性的类型总数,在这里就是data源数据的列数
     26     private int attrNum;
     27     private String filePath;
     28     // 初始源数据,用一个二维字符数组存放模仿表格数据
     29     private String[][] data;
     30     // 数据的属性行的名字
     31     private String[] attrNames;
     32     // 每个属性的值所有类型
     33     private HashMap<String, ArrayList<String>> attrValue;
     34 
     35     public CARTTool(String filePath) {
     36         this.filePath = filePath;
     37         attrValue = new HashMap<>();
     38     }
     39 
     40     /**
     41      * 从文件中读取数据
     42      */
     43     public void readDataFile() {
     44         File file = new File(filePath);
     45         ArrayList<String[]> dataArray = new ArrayList<String[]>();
     46 
     47         try {
     48             BufferedReader in = new BufferedReader(new FileReader(file));
     49             String str;
     50             String[] tempArray;
     51             while ((str = in.readLine()) != null) {
     52                 tempArray = str.split(" ");
     53                 dataArray.add(tempArray);
     54             }
     55             in.close();
     56         } catch (IOException e) {
     57             e.getStackTrace();
     58         }
     59         
     60         data = new String[dataArray.size()][];
     61         dataArray.toArray(data);
     62         attrNum = data[0].length;
     63         attrNames = data[0];
     64     }
     65 
     66     /**
     67      * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
     68      */
     69     public void initAttrValue() {
     70         ArrayList<String> tempValues;
     71 
     72         // 按照列的方式,从左往右找
     73         for (int j = 1; j < attrNum; j++) {
     74             // 从一列中的上往下开始寻找值
     75             tempValues = new ArrayList<>();
     76             for (int i = 1; i < data.length; i++) {
     77                 if (!tempValues.contains(data[i][j])) {
     78                     // 如果这个属性的值没有添加过,则添加
     79                     tempValues.add(data[i][j]);
     80                 }
     81             }
     82 
     83             // 一列属性的值已经遍历完毕,复制到map属性表中
     84             attrValue.put(data[0][j], tempValues);
     85         }
     86     }
     87 
     88     /**
     89      * 计算机基尼指数
     90      * 
     91      * @param remainData
     92      *            剩余数据
     93      * @param attrName
     94      *            属性名称
     95      * @param value
     96      *            属性值
     97      * @param beLongValue
     98      *            分类是否属于此属性值,这里的作用是ID3Tool中的isParent的作用是相同的
     99      * 在belongValue的模式下,在remainData数据中计算属性为atrrName,属性值为value的Gini数
    100      * @return
    101      */
    102     public double computeGini(String[][] remainData, String attrName,
    103             String value, boolean beLongValue) {
    104         // 实例总数
    105         int total = 0;
    106         // 正实例数
    107         int posNum = 0;
    108         // 负实例数
    109         int negNum = 0;
    110         // 基尼指数
    111         double gini = 0;
    112 
    113         // 还是按列从左往右遍历属性
    114         for (int j = 1; j < attrNames.length; j++) {
    115             // 找到了指定的属性
    116             if (attrName.equals(attrNames[j])) {
    117                 for (int i = 1; i < remainData.length; i++) {
    118                     // 统计正负实例按照属于和不属于值类型进行划分
    119                     if ((beLongValue && remainData[i][j].equals(value))
    120                             || (!beLongValue && remainData[i][j].equals(value))) {
    121                         if (remainData[i][attrNames.length - 1].equals(YES)) {
    122                             // 判断此行数据是否为正实例
    123                             posNum++;
    124                         } else {
    125                             negNum++;
    126                         }
    127                     }
    128                 }
    129             }
    130         }
    131 
    132         total = posNum + negNum;
    133         double posProbobly = (double) posNum / total;//计算得到的正比例
    134         double negProbobly = (double) negNum / total;
    135         gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;//统计学习方法中P69中5.24相同
    136 
    137         // 返回计算基尼指数
    138         return gini;
    139     }
    140 
    141     /**
    142      * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中
    143      * 
    144      * @param remainData
    145      *            剩余谁
    146      * @param attrName
    147      *            属性名称
    148      * 与ID3中的概念是不同的,在查找最优的属性是时查找的基尼指数最小的,而在ID3中查找的是信息增益的最大值
    149      * @return
    150      */
    151     public String[] computeAttrGini(String[][] remainData, String attrName) {
    152         String[] str = new String[2];//第一个返回值是应该以哪个属性为分割属性,第二个返回值是此值对应的Gini属性
    153         // 最终该属性的划分类型值
    154         String spiltValue = "";
    155         // 临时变量
    156         int tempNum = 0;
    157         // 保存属性的值划分时的最小的基尼指数,开始并对其进行初始化
    158         double minGini = Integer.MAX_VALUE;
    159         ArrayList<String> valueTypes = attrValue.get(attrName);
    160         // 属于此属性值的实例数
    161         HashMap<String, Integer> belongNum = new HashMap<>();
    162 
    163         for (String string : valueTypes) {//计算每个atrr的值的在剩余的集合中占有多少
    164             // 重新计数的时候,数字归0
    165             tempNum = 0;
    166             // 按列从左往右遍历属性
    167             for (int j = 1; j < attrNames.length; j++) {
    168                 // 找到了指定的属性
    169                 if (attrName.equals(attrNames[j])) {
    170                     for (int i = 1; i < remainData.length; i++) {
    171                         // 统计正负实例按照属于和不属于值类型进行划分
    172                         if (remainData[i][j].equals(string)) {
    173                             tempNum++;
    174                         }
    175                     }
    176                 }
    177             }
    178 
    179             belongNum.put(string, tempNum);
    180         }
    181 
    182         double tempGini = 0;
    183         double posProbably = 1.0;
    184         double negProbably = 1.0;
    185         for (String string : valueTypes) {
    186             tempGini = 0;
    187 
    188             posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);
    189             negProbably = 1 - posProbably;
    190 
    191             tempGini += posProbably
    192                     * computeGini(remainData, attrName, string, true);
    193             tempGini += negProbably
    194                     * computeGini(remainData, attrName, string, false);
    195 
    196             if (tempGini < minGini) {
    197                 minGini = tempGini;
    198                 spiltValue = string;
    199             }
    200         }
    201 
    202         str[0] = spiltValue;//某个属性的值可能不止一个,所以某属性的Gini应该为选择某属性的最小值
    203         
    204         str[1] = minGini + "";
    205 
    206         return str;
    207     }
    208     /*
    209      * 构建的树的过程,与ID3是差不多的,只不过使用的策略不一样
    210      * 在CART中重点是剪枝,应为只是构建树的过程就会导致树的构建是过拟合的过程,通过对树的剪枝,少去过拟合的现象 
    211      */
    212 
    213     public void buildDecisionTree(AttrNode node, String parentAttrValue,
    214             String[][] remainData, ArrayList<String> remainAttr,
    215             boolean beLongParentValue) {
    216         // 属性划分值
    217         String valueType = "";
    218         // 划分属性名称
    219         String spiltAttrName = "";
    220         double minGini = Integer.MAX_VALUE;
    221         double tempGini = 0;
    222         // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值
    223         String[] giniArray;
    224 
    225         if (beLongParentValue) {//belongparentValue指的是否是父节点指定的属性
    226             node.setParentAttrValue(parentAttrValue);
    227         } else {
    228             node.setParentAttrValue("!" + parentAttrValue);
    229         }
    230 
    231         if (remainAttr.size() == 0) {
    232             if (remainData.length > 1) {
    233                 ArrayList<String> indexArray = new ArrayList<>();
    234                 for (int i = 1; i < remainData.length; i++) {
    235                     indexArray.add(remainData[i][0]);//0标记是id 编号
    236                 }
    237                 node.setDataIndex(indexArray);//数据记录索引,用来区分这个节点有哪些个节点
    238             }
    239             System.out.println("attr remain null");
    240             return;
    241         }
    242 
    243         for (String str : remainAttr) {//选择属性中的最小基尼指数
    244             giniArray = computeAttrGini(remainData, str);
    245             tempGini = Double.parseDouble(giniArray[1]);
    246 
    247             if (tempGini < minGini) {
    248                 spiltAttrName = str;
    249                 minGini = tempGini;
    250                 valueType = giniArray[0];
    251             }
    252         }
    253         // 移除划分属性
    254         remainAttr.remove(spiltAttrName);
    255         node.setAttrName(spiltAttrName);//那么Node节点是按照splitAttrName进行划分的,所以Node.setAttrName是splitAttrName
    256 
    257         // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点,孩子节点是和此节点是具有相同的属性
    258         AttrNode[] childNode = new AttrNode[2];
    259         String[][] rData;
    260 
    261         boolean[] bArray = new boolean[] { true, false };
    262         for (int i = 0; i < bArray.length; i++) {
    263             // 二元划分属于属性值的划分,第一次循环中得到的属于splitAttraname 并且属于valueType的节点,第二次循环不属于这个valueType的节点
    264             rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);
    265 
    266             boolean sameClass = true;
    267             ArrayList<String> indexArray = new ArrayList<>();
    268             for (int k = 1; k < rData.length; k++) {
    269                 indexArray.add(rData[k][0]);
    270                 // 判断是否为同一类的
    271                 if (!rData[k][attrNames.length - 1]
    272                         .equals(rData[1][attrNames.length - 1])) {
    273                     // 只要有1个不相等,就不是同类型的
    274                     sameClass = false;
    275                     break;
    276                 }
    277             }
    278 
    279             childNode[i] = new AttrNode();
    280             if (!sameClass) {
    281                 // 创建新的对象属性,对象的同个引用会出错
    282                 ArrayList<String> rAttr = new ArrayList<>();
    283                 for (String str : remainAttr) {
    284                     rAttr.add(str);
    285                 }
    286                 buildDecisionTree(childNode[i], valueType, rData, rAttr,
    287                         bArray[i]);
    288             } else {
    289                 String pAtr = (bArray[i] ? valueType : "!" + valueType);
    290                 childNode[i].setParentAttrValue(pAtr);
    291                 childNode[i].setDataIndex(indexArray);
    292             }
    293         }
    294 
    295         node.setChildAttrNode(childNode);//肯定是要等到子节点处理完了才能把子节点放到孩子节点结合中去。
    296     }
    297 
    298     /**
    299      * 属性划分完毕,进行数据的移除
    300      * 
    301      * @param srcData
    302      *            源数据
    303      * @param attrName
    304      *            划分的属性名称
    305      * @param valueType
    306      *            属性的值类型
    307      * @parame beLongValue 分类是否属于此值类型
    308      */
    309     private String[][] removeData(String[][] srcData, String attrName,
    310             String valueType, boolean beLongValue) {
    311         String[][] desDataArray;
    312         ArrayList<String[]> desData = new ArrayList<>();
    313         // 待删除数据
    314         ArrayList<String[]> selectData = new ArrayList<>();
    315         selectData.add(attrNames);
    316 
    317         // 数组数据转化到列表中,方便移除
    318         for (int i = 0; i < srcData.length; i++) {
    319             desData.add(srcData[i]);
    320         }
    321 
    322         // 还是从左往右一列列的查找
    323         for (int j = 1; j < attrNames.length; j++) {
    324             if (attrNames[j].equals(attrName)) {
    325                 for (int i = 1; i < desData.size(); i++) {
    326                     if (desData.get(i)[j].equals(valueType)) {
    327                         // 如果匹配这个数据,则移除其他的数据
    328                         selectData.add(desData.get(i));
    329                     }
    330                 }
    331             }
    332         }
    333 
    334         if (beLongValue) {
    335             desDataArray = new String[selectData.size()][];
    336             selectData.toArray(desDataArray);
    337         } else {
    338             // 属性名称行不移除
    339             selectData.remove(attrNames);
    340             // 如果是划分不属于此类型的数据时,进行移除
    341             desData.removeAll(selectData);//这里就相当于求了一个补集,与上面的区别就是差这一步和上一步
    342             desDataArray = new String[desData.size()][];
    343             desData.toArray(desDataArray);
    344         }
    345 
    346         return desDataArray;
    347     }
    348 
    349     public void startBuildingTree() {
    350         readDataFile();//将一些文件读入到里面,一些变量进行初始化
    351         initAttrValue();
    352 
    353         ArrayList<String> remainAttr = new ArrayList<>();
    354         // 添加属性,除了最后一个类标号属性
    355         for (int i = 1; i < attrNames.length - 1; i++) {
    356             remainAttr.add(attrNames[i]);
    357         }
    358 
    359         AttrNode rootNode = new AttrNode();
    360         buildDecisionTree(rootNode, "", data, remainAttr, false);//所有都全部初始化
    361         setIndexAndAlpah(rootNode, 0, false);
    362         System.out.println("剪枝前:");
    363         showDecisionTree(rootNode, 1);//1指的是blankNum
    364         setIndexAndAlpah(rootNode, 0, true);
    365         System.out.println("
    剪枝后:");
    366         showDecisionTree(rootNode, 1);
    367     }
    368 
    369     /**
    370      * 显示决策树
    371      * 
    372      * @param node
    373      *            待显示的节点
    374      * @param blankNum
    375      *            行空格符,用于显示树型结构
    376      */
    377     private void showDecisionTree(AttrNode node, int blankNum) {
    378         System.out.println();
    379         for (int i = 0; i < blankNum; i++) {
    380             System.out.print("    ");
    381         }
    382         System.out.print("--");
    383         // 显示分类的属性值
    384         if (node.getParentAttrValue() != null
    385                 && node.getParentAttrValue().length() > 0) {
    386             System.out.print(node.getParentAttrValue());
    387         } else {
    388             System.out.print("--");
    389         }
    390         System.out.print("--");
    391 
    392         if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
    393             String i = node.getDataIndex().get(0);
    394             System.out.print("【" + node.getNodeIndex() + "】类别:"
    395                     + data[Integer.parseInt(i)][attrNames.length - 1]);
    396             System.out.print("[");
    397             for (String index : node.getDataIndex()) {
    398                 System.out.print(index + ", ");
    399             }
    400             System.out.print("]");
    401         } else {
    402             // 递归显示子节点
    403             System.out.print("【" + node.getNodeIndex() + ":"
    404                     + node.getAttrName() + "】");
    405             if (node.getChildAttrNode() != null) {
    406                 for (AttrNode childNode : node.getChildAttrNode()) {
    407                     showDecisionTree(childNode, 2 * blankNum);
    408                 }
    409             } else {
    410                 System.out.print("【  Child Null】");
    411             }
    412         }
    413     }
    414 
    415     /**
    416      * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝
    417      * 
    418      * @param node
    419      *            开始的时候传入的是根节点
    420      * @param index
    421      *            开始的索引号,从1开始
    422      * @param ifCutNode
    423      *            是否需要剪枝
    424      *  计算的过程应该是自上而下的进行
    425      */
    426     private void setIndexAndAlpah(AttrNode node, int index, boolean ifCutNode) {
    427         AttrNode tempNode;
    428         // 最小误差代价节点,即将被剪枝的节点
    429         AttrNode minAlphaNode = null;
    430         double minAlpah = Integer.MAX_VALUE;//即设置alpha的初始值
    431         Queue<AttrNode> nodeQueue = new LinkedList<AttrNode>();
    432 
    433         nodeQueue.add(node);
    434         while (nodeQueue.size() > 0) {
    435             index++;
    436             // 从队列头部获取首个节点并给予编号,并移除,使用队列,是使用的是宽度优先搜索
    437             tempNode = nodeQueue.poll();
    438             tempNode.setNodeIndex(index);
    439             if (tempNode.getChildAttrNode() != null) {
    440                 for (AttrNode childNode : tempNode.getChildAttrNode()) {
    441                     nodeQueue.add(childNode);
    442                 }
    443                 computeAlpha(tempNode);
    444                 if (tempNode.getAlpha() < minAlpah) {
    445                     minAlphaNode = tempNode;
    446                     minAlpah = tempNode.getAlpha();
    447                 } else if (tempNode.getAlpha() == minAlpah) {
    448                     // 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点
    449                     if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) {
    450                         minAlphaNode = tempNode;
    451                     }
    452                 }
    453             }
    454         }
    455 
    456         if (ifCutNode) {
    457             // 进行树的剪枝,让其左右孩子节点为null,ifCutNode为真是指的是剪枝之后,ifCutNode为假的指的是剪枝之前
    458             minAlphaNode.setChildAttrNode(null);
    459         }
    460     }
    461 
    462     /**
    463      * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝
    464      * 
    465      * @param node
    466      *            待计算的非叶子节点
    467      */
    468     private void computeAlpha(AttrNode node) {
    469         double rt = 0;//这里的rt和Rt和P73中C(t)的表达是相同的
    470         double Rt = 0;
    471         double alpha = 0;
    472         // 当前节点的数据总数
    473         int sumNum = 0;
    474         // 最少的偏差数
    475         int minNum = 0;
    476 
    477         ArrayList<String> dataIndex;
    478         ArrayList<AttrNode> leafNodes = new ArrayList<>();
    479 
    480         addLeafNode(node, leafNodes);//leafNodes开始为空,在addLeafNode函数中进行更新
    481         /*
    482          * System.out.println("node.attr="+node.getAttrName()+"	leafNodes.length="+leafNodes.size());//输出每个节点分成的叶子节点的信息
    483          
    484             for(int i=0;i<leafNodes.size();i++)
    485                 System.out.print(leafNodes.get(i).getDataIndex()+"	");
    486             System.out.println();
    487         */
    488         node.setLeafNum(leafNodes.size());
    489         for (AttrNode attrNode : leafNodes) {//leafNodes中包含Node所分的所有的叶子节点
    490             dataIndex = attrNode.getDataIndex();
    491 
    492             int num = 0;
    493             sumNum += dataIndex.size();
    494             for (String s : dataIndex) {
    495                 // 统计分类数据中的正负实例数
    496                 if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {
    497                     num++;
    498                 }
    499             }
    500             minNum += num;//minNum保存着node中正实例的个数
    501 
    502             // 取小数量的值部分
    503             if (1.0 * num / dataIndex.size() > 0.5) {
    504                 num = dataIndex.size() - num;
    505             }
    506 
    507             rt += (1.0 * num / (data.length - 1));
    508         }
    509         
    510         //同样取出少偏差的那部分,因为要找到损失函数,所以是要把正实例和负实例中较小的部分
    511         if (1.0 * minNum / sumNum > 0.5) {
    512             minNum = sumNum - minNum;
    513         }
    514 
    515         Rt = 1.0 * minNum / (data.length - 1);
    516         alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);//计算G(t)
    517         node.setAlpha(alpha);
    518     }
    519 
    520     /**
    521      * 筛选出节点所包含的叶子节点数
    522      * 
    523      * @param node
    524      *            待筛选节点
    525      * @param leafNode
    526      *            叶子节点列表容器
    527      */
    528     private void addLeafNode(AttrNode node, ArrayList<AttrNode> leafNode) {
    529         ArrayList<String> dataIndex;
    530 
    531         if (node.getChildAttrNode() != null) {
    532             for (AttrNode childNode : node.getChildAttrNode()) {
    533                 dataIndex = childNode.getDataIndex();
    534                 if (dataIndex != null && dataIndex.size() > 0) {//dataIndex这个集合的元素
    535                     // 说明此节点为叶子节点
    536                     //System.out.println("leafNode.getLeafNum="+childNode.getLeafNum());
    537                     
    538                     leafNode.add(childNode);
    539                 } else {
    540                     // 如果还是非叶子节点则继续递归调用
    541                     
    542                     //System.out.println("middle Node.getLeafNum="+childNode.getLeafNum());
    543                     addLeafNode(childNode, leafNode);
    544                 }
    545             }
    546         }
    547     }
    548 
    549 }
    View Code

    训练数据input.txt

     1 Rid Age Income Student CreditRating BuysComputer
     2 1 Youth High No Fair No
     3 2 Youth High No Excellent No
     4 3 MiddleAged High No Fair Yes
     5 4 Senior Medium No Fair Yes
     6 5 Senior Low Yes Fair Yes
     7 6 Senior Low Yes Excellent No
     8 7 MiddleAged Low Yes Excellent Yes
     9 8 Youth Medium No Fair No
    10 9 Youth Low Yes Fair Yes
    11 10 Senior Medium Yes Fair Yes
    12 11 Youth Medium Yes Excellent Yes
    13 12 MiddleAged Medium No Excellent Yes
    14 13 MiddleAged High Yes Fair Yes
    15 14 Senior Medium No Excellent No
    View Code

    输出的结果如图所示:

     1 attr remain null
     2 attr remain null
     3 剪枝前:
     4 
     5     --!--【1:Age】
     6         --MiddleAged--【2】类别:Yes[3, 7, 12, 13, ]
     7         --!MiddleAged--【3:Income】
     8                 --High--【4】类别:No[1, 2, ]
     9                 --!High--【5:Student】
    10                                 --Yes--【6:CreditRating】
    11                                                                 --Fair--【8】类别:Yes[5, 9, 10, ]
    12                                                                 --!Fair--【9】类别:No[6, 11, ]
    13                                 --!Yes--【7:CreditRating】
    14                                                                 --Excellent--【10】类别:No[14, ]
    15                                                                 --!Excellent--【11】类别:Yes[4, 8, ]
    16 剪枝后:
    17 
    18     --!--【1:Age】
    19         --MiddleAged--【2】类别:Yes[3, 7, 12, 13, ]
    20         --!MiddleAged--【3:Income】
    21                 --High--【4】类别:No[1, 2, ]
    22                 --!High--【5:Student】
    23                                 --Yes--【6:CreditRating】【  Child Null】
    24                                 --!Yes--【7:CreditRating】
    25                                                                 --Excellent--【10】类别:No[14, ]
    26                                                                 --!Excellent--【11】类别:Yes[4, 8, ]
    View Code
  • 相关阅读:
    简单两行,实现无线WiFi共享上网,手机抓包再也不用愁了
    Windows下Python 3.6 安装BeautifulSoup库
    RSA加密算法破解及原理
    干货,Wireshark使用技巧-过滤规则
    干货:Wireshark使用技巧-显示规则
    干货!链家二手房数据抓取及内容解析要点
    Wireshark分析实战:某达速递登录帐号密码提取
    协议分析中的TCP/IP网络协议
    Wireshark使用教程:不同报文颜色的含义
    VMware kali虚拟机环境配置
  • 原文地址:https://www.cnblogs.com/huicpc0212/p/4361119.html
Copyright © 2011-2022 走看看