zoukankan      html  css  js  c++  java
  • 决策树(一)

    简介

    决策树作为预测模型在统计、数据挖掘和机器学习中应用广泛。决策树结构中,叶节点表示分类,而树中分支表示特征的结合产生某个分类。决策树模型中目标变量如果是一个有限的值集合,则称分类树,如果是连续的变量(通常为实数)则称为回归树。

    决策树学习的目标是创建一个模型,使得能够根据一组输入变量预测输出变量。例如下图,每个内部节点对应一个输入变量,每个输入变量(节点)根据不同的取值形成到不同的子节点的边(edge)。每个叶节点表示目标变量的一个值,这个值由根节点到这个叶节点的路径决定。

    假定,输入的每个特征的值域均为有限离散的,目标特征只有一个即表示分类。决策树中非叶节点表示某一个输入特征(或输入特征向量的一维),从某个非叶节点出发,根据对应这个维度的不同取值形成不同的边,并到达子节点,子节点可能是一颗子决策树或者是叶节点(代表某个分类)。

     决策树类型

    主要分两类

    • 分类树,预测一个新的输入实例属于哪个分类
    • 回归树,预测输出变量的值(通常为实数,例如一套房子的价格,或者一位病人的住院时间)

    术语Classification and regression tree(CART) 分析是一个涵盖性术语,通常指上面的分析过程。分类和回归树有一些相似的地方(当然也有区别,比如切分数据集的过程不同)

    以下一组方法可以构建一个或多个决策树

    Boosted Trees 

    Bootstrap aggregated

    • Random Forest是一个特殊的Bootstrap aggregating

    Rotation forest

    决策树学习就是从训练数据集构建决策树。决策树算法有很多,其中比较有名的如下:

    • ID3 (Iterative Dichotomiser 3)
    • C4.5 (successor of ID3)
    • CART (Classification And Regression Tree)
    • CHAID (CHi-squared Automatic Interaction Detector). Performs multi-level splits when computing classification trees.
    • MARS: extends decision trees to handle numerical data better.
    • Conditional Inference Trees. Statistics-based approach that uses non-parametric tests as splitting criteria, corrected for multiple testing to avoid overfitting. This approach results in unbiased predictor selection and does not require pruning

    度量

    构建决策树的算法通常都是从上至下的,每一步选择一个输入变量,目的是根据这个变量的不同值将当前的数据点集合切分,使得切分效果最佳。不同的算法使用不同的方式衡量“最佳”。

    基尼不纯度(Gini impurity)

    基尼不纯度在CART(classification and regression tree)算法中使用,是指一种衡量系统混乱程度的测量方法:在一个集合中随机选择一个元素,基于该集合中标签的概率分布为元素分配标签的错误率。基尼不纯度可以这样计算,每一项分配标签为i的概率fi乘以这项没有分配为标签i的概率 1 - fi,得到一个积,最后相加所有标签的积。

    其中fi表示数据集中标签i的出现次数与数据集数量的比例,J表示总共有J个标签。

    其实上面的说法不是很明白,简单点说,假设一个数据集中有N个数据点,总共涉及 J 个标签,第 j 个标签出现的 次数为 nj, 那么随机抽取一个数据点,标签 j 出现的概率为 fj = nj/N,那么对于标签 j 来说,随机选择一个数据点,本该被判断为标签j 但却错误的被判断为其他标签,则 这个概率为 fj(1-fj),基尼不纯度就是 每个标签 的这种概率之和,表征一个系统的混乱程度,

     信息增益(Information gain)

    使用信息增益的算法有ID3, C4.5以及C5.0。信息增益基于信息理论中的熵的概念。

    信息熵是表示随机变量不确定性的度量。设X为一个离散随机变量,概率分布为

    P(X=xi) = pi, i=1,2,...,n ,X的取值范围为n个离散值

    则X的熵定义为

       (1)

    其中 log 是 以 2 为底 的对数函数。

    熵越大,表示X不确定越大。

    设有随机变量(X, Y),联合概率分布为

    P(X=xi, Y=yj)=pij, i=1,2, ... ,n; j= 1,2, ... , m

    条件熵H(Y|X)表示在已知随机变量X的条件下Y的不确定性。H(Y|X)定义为给定X的条件下Y的条件概率分布的熵对X的数学期望

         (2)

    其中 表示 X=xi时Y=yj的条件概率

    熵和条件熵的概率由数据估计得到时,对应的熵和条件熵分别成为经验熵和经验条件熵。 

    特征A对训练数据集D的信息增益定义为,集合D的经验熵H(D)与特征A给定条件下D的经验条件熵H(D|A)之差

          (3)

    信息增益用于决定使用哪个特征来切分数据集来生成决策树。简单即是最好,所以我们尽量使得树越小越好,每一步的切分都使得子节点尽可能的“纯”。衡量“纯”的一个方法是信息,信息单位为比特(bit,与计算机中的bit不同)。

    给定训练数据集D和特征A,经验熵H(D)表示对数据集D分类的不确定性(混乱性,无序性),熵越高,不确定性越大(刚才说的“纯”,就是指熵值越小越“纯”)。而条件经验熵H(D|A)表示特征A给定的情况下对数据集D分类的不确定性。两者的差即信息增益,表示由于特征A确定后使得对数据集D分类的不确定性减少的程度,显然由于不确定性减少,所以信息增益为非负。信息增益越大,切分后的子节点的熵越小,表示这个特征的分类能力越强。

    信息增益的计算

    假设训练数据集为D,|D|表示数据集样本容量(个数),K个分类,每个分类为Ck,k=1,2, ... , K ,数据集中分类为Ck的样本个数为|Ck|,设特征A有n个不同的取值{a1, a2, ... , an},根据特征A可以将数据集D切分为n个子集D1,D2, ... , Dn(集合论中称为D的一个划分partition),子集Di的样本个数为|Di|,Di中分类为Ck的样本集合为Dik,Di中分类为Ck个数为|Dik|,显然有

    数据集D的经验熵(经验熵,即根据已知样本数据集D估计特征空间的数据分类不确定性,所以此时,分类是随机变量,根据(1)式有如下等式)

      , k = 1, 2, ... , K,(4)

    特征A对数据集D的经验条件熵(根据(2)式)

       (5)

    将(4)带入(5),得

      (6)

    信息增益再使用(3)即可计算得到。

    例子

    考虑以下一个例子,数据集中包含14个数据,数据有4个属性(4个特征):天气(晴,阴,雨),温度(热,温,凉),湿度(高,低),有风(是,否),目标变量为 “出去玩”,值为(是,否)。一开始,基于每个属性,我们都可以用来切分数据集,然后开始构造决策树,这样就可以有4个决策树,当然,我们需要比较这4个决策树的信息增益,具有最高信息增益的决策树对应的切分(比如基于天气切分)就是选定的用于第一步的切分(于是第一步根据天气切分),然后继续这样切分,直到子节点已经很“纯”了,或者信息增益为0。

    表征系统的分类不确定性的熵,训练数据集样本个数14,分类中有9个正例,5个负例

    H(D) = -(9/14)*log(9/14) - (5/14)*log(5/14)=0.940286

    比如,基于是否有风来切分数据集,可以获得两个子节点,一个表示有风,一个表示无风,在上面的数据集中,总共6个数据点表示有风,其中表示“玩”和“不玩”的数据点各为3个, 然后剩余的8个数据点表示无风,其中表示“玩”和“不玩”的数据点数量分别为6和2,那么各个子系统的熵为

    H(windy)= - (6/14)*(3/6 * log(3/6))*2=0.428571

    H(windless)= – (8/14)(6/8*log(6/8) + 2/8*log(2/8))=0.4635875

    根据是否有风划分后的信息增益

    g(D|A) = H(D) - H(windy) - H(windless) = 0.0481275

    类似的,根据天气来划分数据集,

    H(sunny)=-(5/14)*(2/5*log(2/5)+3/5*log(3/5))=0.346768

    H(overcast)=-(4/14)*(4/4*log(4/4)+0/4*log(0/4))=0

    H(rain)=-(5/14)*(2/5*log(2/5)+3/5*log(3/5))=0.346768

    根据天气划分后的信息增益

    g(D|A)=H(D)-H(sunny)-H(overcast)-H(rain)=0.24675

    然后再分别计算出根据湿度和温度划分的信息增益,比较这四个信息增益,最大的信息增益对应的特征就是我们第一步所选择的特征(信息增益越大,区分样本的能力就越强,越具有代表性,这是一种自顶向下的贪心策略),使用此特征划分数据集后,继续上面的步骤,计算此后另外三个特征划分后的信息增益,选择信息增益最大对应的特征作为第二步中的特征来划分数据集,依次进行直到信息增益很小或没有特征可以选择为止。

    以上步骤就是ID3算法的基本思想。

    下面总结一下生成决策树的步骤:

    输入:数据集D,特征集A,阈值ε 

    输出:决策树T

    1. 若D中所有实例属于同一类Ck,则T为单节点树,该节点类标记为Ck,返回T
    2. 若A=∅,则T为单节点数,找出最大值|Ck|,将该节点类标记为Ck,返回T
    3. 根据(3)式算出A中各特征对D的信息增益,选择信息增益最大的特征Ag
    4. 如果Ag的信息增益小于阈值,则T为单节点数,找出最大值|Ck|,将该节点类标记为Ck,返回T;
    5. 否则,对Ag的每一个取值,将D划分为ng个子集Di,其中ng为特征Ag的可取值数量,将Di中实例数最大的类作为当前节点的类标记,然后构建ng个子节点,由当前节点和ng个子节点构成决策树T,返回T
    6. 对上面步骤5中的ng个子节点,第i个子节点,以Di为数据集,A-{Ag}为特征集,递归调用步骤1~步骤5,得到子树Ti。

    C4.5算法

    C4.5算法与ID3算法类似,只是将信息增益替换为信息增益比, 然后来选择信息增益比最大的特征(属性),详情可参考下面的代码。

    C4.5还增加了对连续型属性变量的处理,比如将属性排序,然后选择在某个地方(分割值为V)切分为两个部分,分别为属性值大于V和小于V的两部分。具体方法如下:

    假设连续型属性为A,将样本点根据A的值升序排序,样本点个数为N,那最简单的就是有N-1个分割点,分割值可以取两个相邻点A值平均数,分别计算每一个分割点(分割值为V)所分的两部分的信息增益(这两部分分别为属性值大于V和小于V的两部分),从而得到N-1个信息增益,选择信息增益最大的那个分割点,从而将属性A的值分成离散的两部分。

    然而上面这种计算N-1个信息增益的方式计算量一般都很大。一种降低计算量的方法为按属性A的值升序排列后,将决策类型相同的连续样本点的A属性值作为一个整体,不设分割点,比如,

     (图来自这里,侵删)

    属性A为温度,是连续值,右图中红色的线表示分割(从图中可以温度为18时,分割值可能也是18,此时分割的两部分就是 大于等于18和小于等于18 这两部分),温度为6 和8 之间不在设置分割点,因为其分类值一样。这样分割点数量就少了很多。

    注意,选择最优分割点是使用信息增益作为依据,选择了最优分割点后将数据在此属性上分为两部分,然后再将此属性与其他属性采用信息增益比作为判断依据来选择最佳分割属性。

    ref:

    代码

        public class DTree
        {
            private Node _root = new Node();
            /// <summary>
            /// 决策树根结点
            /// </summary>
            public Node Root { get { return _root; } }
            /// <summary>
            /// 给定训练数据集创建决策树
            /// </summary>
            /// <param name="data"></param>
            /// <returns></returns>
            public static DTree Create(Data data)
            {
                var dtree = new DTree();
                var target = data.Target;
                var points = data.Examples;
                Create(data.Examples, data.Attributes.Keys.ToList(), data.Target, dtree._root, data.Attributes);
                return dtree;
            }
    
    
            private static void Create(List<Dictionary<string, string>> points, List<string> attrNames, string target, Node node, Dictionary<string, string[]> attrs)
            {
                if(attrNames.Count == 1)    // 没有属性(或者只剩一个目标target)
                {
                    node.Class = Util.GetDefaultClass(points, target);  // 使用当前数据集中数量最多的分类
                }
                else
                {
                    var bestAttr = Util.GetBestAttr(points, attrNames, target);     // 最佳属性
                    var targetVals = points.Select(p => p[target]);                 // 当前数据集中分类值枚举
                    node.Attr = bestAttr;
                    if (Util.IsSameElem(targetVals))
                    {
                        // 剩余的数据点分类值相同,node为叶节点
                        node.Class = targetVals.First();
                    }
                    else
                    {
                        node.Children = new Dictionary<string, Node>();
                        // 根据最佳属性划分得到子空间枚举,遍历
                        foreach (var v in attrs[bestAttr])
                        {
                            var list = Util.GetSubRegion(points, bestAttr, v);
                            if (list.Count == 0)      
                            {
                                // 子空间无数据点,则使用父空间的默认分类
                                var newNode = new Node();
                                newNode.Class = Util.GetDefaultClass(points, target);
                                node.Children.Add(v, newNode);
                            }
                            else
                            {
                                // 子空间有数据点
                                var newAttrNames = attrNames.Where(n => n != bestAttr).ToList();
                                var newNode = new Node();
                                node.Children.Add(v, newNode);
                                Create(list, newAttrNames, target, newNode, attrs);
                            }
                        }
                    }
                }
            }

    /// <summary>
            /// 给定数据点判别分类
            /// </summary>
            /// <param name="point"></param>
            /// <returns></returns>
            public string Judge(Dictionary<string, string> point) => Judge(_root, point);
    
            private string Judge(Node node, Dictionary<string, string> point)
            {
                if(node.Attr == null)
                {
                    return node.Class;
                }
                else
                {
                    var attrVal = point[node.Attr];
                    return Judge(node.Children[attrVal], point);
                }
            }
        }
    
        public class Node
        {
            /// <summary>
            /// 用于划分的属性名,叶节点为null
            /// </summary>
            public string Attr { get; set; }
            /// <summary>
            /// 节点分类,只有叶节点有分类值,内部节点为null
            /// </summary>
            public string Class { get; set; }
            /// <summary>
            /// 根据属性的取值划分子空间,叶节点为null
            /// key为属性值,value为对应的子树的根结点,表示子空间
            /// </summary>
            public Dictionary<string, Node> Children { get; set; }
        }
    
        public class Util
        {
            /// <summary>
            /// 判断枚举中各元素是否相等
            /// </summary>
            /// <param name="enums"></param>
            /// <returns></returns>
            public static bool IsSameElem(IEnumerable<string> enums)
            {
                string first = null;
                foreach (var @enum in enums)
                {
                    if (first == null)
                        first = @enum;
                    else if (first != @enum)
                        return false;
                }
                return true;
            }
            /// <summary>
            /// 给定数据集以及相关的属性和对应属性值,获取子数据集
            /// </summary>
            /// <param name="points"></param>
            /// <param name="attr"></param>
            /// <param name="val"></param>
            /// <returns></returns>
            public static List<Dictionary<string, string>> GetSubRegion(List<Dictionary<string, string>> points, string attr, string val)
            {
                var list = new List<Dictionary<string, string>>();
                foreach (var point in points)
                {
                    if (point[attr] == val)
                        list.Add(point);
                }
                return list;
            }
    
            /// <summary>
            /// 计算给定数据集系统以及指定属性作为随机变量的经验熵
            /// </summary>
            /// <param name="examples"></param>
            /// <param name="attr"></param>
            /// <returns></returns>
            public static double Entropy(List<Dictionary<string, string>> points, string attr)
            {
                var freqs = new Dictionary<string, double>();      // 指定属性,每个值出现的次数映射
                var count = points.Count;                           // 数据集大小
                foreach(var point in points)
                {
                    var attrVal = point[attr];          // 数据点指定属性的值
                    if (freqs.ContainsKey(attrVal))
                        freqs[attrVal] += 1.0;
                    else
                        freqs[attrVal] = 1.0;
                }
    
                // 计算熵
                double h = 0;
                foreach(var f in freqs)
                {
                    h += (-f.Value / count) * Math.Log(f.Value / count, 2);
                }
                return h;
            }
            /// <summary>
            /// 信息增益
            /// </summary>
            /// <param name="points"></param>
            /// <param name="attr"></param>
            /// <param name="target"></param>
            /// <returns></returns>
            public static double Gain(List<Dictionary<string, string>> points, string attr, string target)
            {
                // 计算条件经验熵,给定attr属性
    
                // 根据attr的取值,划分数据集空间,每个子空间的空间名与子数据集映射
                var subRegions = new Dictionary<string, List<Dictionary<string, string>>>();  // 子空间名为attr的值
                foreach(var point in points)
                {
                    var attrVal = point[attr];
                    if (subRegions.ContainsKey(attrVal))
                        subRegions[attrVal].Add(point);
                    else
                        subRegions[attrVal] = new List<Dictionary<string, string>>() { point };
                }
    
                var count = (double)points.Count;       // 数据集总大小
                double hc = 0;
                foreach (var p in subRegions)
                {
                    var subRegion_Prob = p.Value.Count / count;     // 属性值为attrVal的经验概率
                    var subRegion_Entropy = Entropy(p.Value, target);   // 子系统对target这个随机变量的熵
                    hc += subRegion_Prob * subRegion_Entropy;
                }
    
                var h = Entropy(points, target);    // 划分前的系统对target这个随机变量的经验熵
                return h - hc;          // 增益
            }
            /// <summary>
            /// 获取信息增益比
            /// 给定某属性,求数据集系统的信息增益比
            /// </summary>
            /// <param name="points">数据集</param>
            /// <param name="attr">给定的某属性</param>
            /// <param name="target">目标属性,即,分类,信息熵以分类作为随机变量来计算得到</param>
            /// <returns></returns>
            public static double GainRatio(List<Dictionary<string, string>> points, string attr, string target)
            {
                // 计算条件经验熵,给定attr属性
    
                // 根据attr的取值,划分数据集空间,每个子空间的空间名与子数据集映射
                var subRegions = new Dictionary<string, List<Dictionary<string, string>>>();  // 子空间名为attr的值
                foreach (var point in points)
                {
                    var attrVal = point[attr];
                    if (subRegions.ContainsKey(attrVal))
                        subRegions[attrVal].Add(point);
                    else
                        subRegions[attrVal] = new List<Dictionary<string, string>>() { point };
                }
    
                var count = (double)points.Count;       // 数据集总大小
                double hc = 0;
                foreach (var p in subRegions)
                {
                    var subRegion_Prob = p.Value.Count / count;     // 属性值为attrVal的经验概率
                    var subRegion_Entropy = Entropy(p.Value, target);   // 子系统对target这个随机变量的熵
                    hc += subRegion_Prob * subRegion_Entropy;
                }
    
                var h = Entropy(points, target);    // 划分前的系统对target这个随机变量的经验熵
                return 1 - hc/h;          // 信息增益比
            }
    
            /// <summary>
            /// 获取最佳划分的属性
            /// </summary>
            /// <param name="points"></param>
            /// <param name="attrs"></param>
            /// <param name="target"></param>
            /// <returns></returns>
            public static string GetBestAttr(List<Dictionary<string, string>> points, List<string> attrs, string target)
            {
                string bestAttr = null;
                double maxGain = 0;
    
                foreach(var attr in attrs)
                {
                    if (attr == target) continue;
    
                    var gain = Gain(points, attr, target);      // ID3算法,如果是C4.5算法,则将Gain函数替换为GainRatio函数
                    if(maxGain < gain || bestAttr == null)
                    {
                        bestAttr = attr;
                        maxGain = gain;
                    }
                }
                return bestAttr;
            }
    
            /// <summary>
            /// 获取默认分类,默认分类是指具有分类值的数据点最多
            /// 在属性用完或者增益为0时,使用默认分类
            /// </summary>
            /// <param name="points"></param>
            /// <param name="target"></param>
            /// <returns></returns>
            public static string GetDefaultClass(List<Dictionary<string, string>> points, string target)
            {
                var freqs = new Dictionary<string, int>();  // 分类值与其数量的映射
                foreach(var point in points)
                {
                    var val = point[target];
                    if (freqs.ContainsKey(val))
                        freqs[val] += 1;
                    else
                        freqs[val] = 1;
                }
    
                string cls = null;  // 分类
                int clsCount = 0;   // 分类对应的数量
                foreach(var p in freqs)
                {
                    if(cls == null || p.Value > clsCount)
                    {
                        cls = p.Key;
                        clsCount = p.Value;
                    }
                }
                return cls;
            }
        }
        /// <summary>
        /// 数据
        /// </summary>
        public class Data
        {
            public static Regex attrRegex = new Regex(@"^@ATTRIBUTEs+(?<name>.*?)s+{(?<values>(.+))}$", RegexOptions.Compiled);
            /// <summary>
            /// 数据点集合
            /// </summary>
            public List<Dictionary<string, string>> Examples { get; set; }
            /// <summary>
            /// 属性字典,最后一个为目标属性
            /// key为属性名,value为属性的可取值
            /// </summary>
            public Dictionary<string, string[]> Attributes { get; set; }
            /// <summary>
            /// 目标属性名
            /// </summary>
            public string Target { get; set; }
            /// <summary>
            /// 给定数据文件地址,创建数据对象
            /// </summary>
            /// <param name="path"></param>
            /// <returns></returns>
            public static Data Create(string path)
            {
                var attrs = new List<Tuple<string, string[]>>();
                string target = null;
                var lines = File.ReadAllLines(path);
                var examples = new List<Dictionary<string, string>>();
                foreach (var line in lines)
                {
                    if (string.IsNullOrWhiteSpace(line)) continue;
    
                    if (line.StartsWith("@ATTRIBUTE"))
                    {
                        var match = attrRegex.Match(line);
                        var name = match.Groups["name"].Value;
                        var values = match.Groups["values"].Value.Split(new[] { ',', ' ' }, StringSplitOptions.RemoveEmptyEntries);
    
                        attrs.Add(new Tuple<string, string[]>(name, values));
                        target = name;
                    }
                    else if (line[0] != '@')
                    {
                        var segs = line.Split(new[] { ' ', '	' }, StringSplitOptions.RemoveEmptyEntries);
                        var example = new Dictionary<string, string>();
                        for (int i = 0; i < segs.Length; i++)
                        {
                            var name = attrs[i].Item1;
                            var value = segs[i];
                            example.Add(name, value);
                        }
                        examples.Add(example);
                    }
                }
                var data = new Data() { Examples = examples, Attributes = attrs.ToDictionary(t => t.Item1, t => t.Item2), Target = target };
                return data;
            }
        }

    数据文件为

    @ATTRIBUTE outlook            {sunny, overcast, rain}
    @ATTRIBUTE temperature        {hot, mild, cool}
    @ATTRIBUTE humidity            {high, normal}
    @ATTRIBUTE windy            {true, false}
    @ATTRIBUTE target            {yes, no}
    
    @data
    sunny        hot        high    false    no
    sunny        hot        high    true    no
    overcast    hot        high    false    yes
    rain        mild    high    false    yes
    rain        cool    normal    false    yes
    rain        cool    normal    true    no
    overcast    cool    normal    true    yes
    sunny        mild    high    false    no
    sunny        cool    normal    false    yes
    rain        mild    normal    false    yes
    sunny        mild    normal    true    yes
    overcast    mild    high    true    yes
    overcast    hot        normal    false    yes
    rain        mild    high    true    no

    (代码未测试,仅作参考已帮助理解决策树生成和决策过程)

  • 相关阅读:
    PyQt(Python+Qt)学习随笔:containers容器类部件QStackedWidget重要方法介绍
    什么叫工业4.0,这篇接地气的文章终于讲懂了
    怎样 真正认识一个 人
    华为的绩效管理:减人、增 效、加薪
    羽毛球战术
    魔方教程
    员工培养:事前指导,事后纠正
    一把手瞄准哪里,核心竞争力就在哪里
    海尔的五次战略变革
    如何提高基层员工的执行力
  • 原文地址:https://www.cnblogs.com/sjjsxl/p/6880404.html
Copyright © 2011-2022 走看看