  • A Java implementation of the ID3 decision tree algorithm (works for basically any ID3 problem)

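    Given: an influenza training dataset with two predefined classes (flu / no flu);
    Find: a decision tree, built with the ID3 algorithm, that describes influenza in terms of the attributes.
    Influenza training dataset: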

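    No. | Headache | Muscle pain | Temperature    | Has flu
    ----+----------+-------------+----------------+--------
    1   | yes (1)  | yes (1)     | normal (0)     | no (0)
    2   | yes (1)  | yes (1)     | high (1)       | yes (1)
    3   | yes (1)  | yes (1)     | very high (2)  | yes (1)
    4   | no (0)   | yes (1)     | normal (0)     | no (0)
    5   | no (0)   | no (0)      | high (1)       | no (0)
    6   | no (0)   | yes (1)     | very high (2)  | yes (1)
    7   | yes (1)  | no (0)      | high (1)       | yes (1)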

    How it works:

     

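    Before splitting each non-leaf node of the decision tree, compute the information gain of every candidate attribute and split on the attribute with the largest gain. The larger the information gain, the better the attribute separates the samples, and the more representative it is.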
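Here is a minimal, self-contained sketch of this selection rule, separate from the implementation below. It encodes the seven rows with the numeric codes given in parentheses in the table, computes the information gain of each attribute over the whole set, and picks the largest. The class `GainSketch` and its method names are hypothetical. It only handles the root split and does not track attributes already used deeper in the tree.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

// Hypothetical sketch: choose the root attribute by maximum information gain
public class GainSketch {
    // Rows use the table's numeric codes: {headache, muscle pain, temperature, has flu}
    static final int[][] DATA = {
        {1, 1, 0, 0}, {1, 1, 1, 1}, {1, 1, 2, 1}, {0, 1, 0, 0},
        {0, 0, 1, 0}, {0, 1, 2, 1}, {1, 0, 1, 1}
    };
    static final int LABEL = 3; // class column
    static final String[] NAMES = {"headache", "muscle pain", "temperature"};

    // Entropy (in bits) of the class label over a list of rows
    static double entropy(List<int[]> rows) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int[] r : rows) counts.merge(r[LABEL], 1, Integer::sum);
        double h = 0;
        for (int c : counts.values()) {
            double p = (double) c / rows.size();
            h -= p * Math.log(p) / Math.log(2);
        }
        return h;
    }

    // Gain(S, A) = H(S) - sum over values v of |S_v|/|S| * H(S_v)
    static double gain(List<int[]> rows, int attr) {
        Map<Integer, List<int[]>> parts = new HashMap<>();
        for (int[] r : rows) parts.computeIfAbsent(r[attr], k -> new ArrayList<>()).add(r);
        double remainder = 0;
        for (List<int[]> part : parts.values()) {
            remainder += (double) part.size() / rows.size() * entropy(part);
        }
        return entropy(rows) - remainder;
    }

    // Index of the attribute with the largest gain (the first one wins on a tie)
    static int bestAttribute(List<int[]> rows) {
        int best = -1;
        double bestGain = -1;
        for (int a = 0; a < NAMES.length; a++) {
            double g = gain(rows, a);
            if (g > bestGain) { bestGain = g; best = a; }
        }
        return best;
    }

    public static void main(String[] args) {
        List<int[]> rows = Arrays.asList(DATA);
        for (int a = 0; a < NAMES.length; a++) {
            System.out.printf(Locale.ROOT, "Gain(%s) = %.4f%n", NAMES[a], gain(rows, a));
        }
        System.out.println("Split on: " + NAMES[bestAttribute(rows)]);
    }
}
```

For this table the gains come out to about 0.1281 (headache), 0.0060 (muscle pain) and 0.5917 (temperature), so temperature is chosen as the root.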

     

    Computing the entropy:
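Take a sample set $S$ whose samples fall into classes $1,\dots,k$ (here $k = 2$: flu or no flu), and let $S_i$ be the samples of class $i$. The entropy is the standard definition below. `Entropy.getEntropy` computes one term $-p_i \log_2 p_i$ of the sum, with $p_i$ = `getShang(x, total)`. It returns 0 when `x == 0`, which matches the convention $0 \log_2 0 = 0$.

```latex
H(S) = -\sum_{i=1}^{k} p_i \log_2 p_i, \qquad p_i = \frac{|S_i|}{|S|}
```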

     

     

    Information gain:
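Information gain measures how much splitting $S$ on attribute $A$ lowers the entropy, where $S_v$ is the subset of $S$ whose value of $A$ is $v$. In the code, `UtilID3.gain` computes $H(S)$ as `entropyS` and gets each term of the sum from `UtilID3.entropy`. That method returns `getShang(total, totals) * entropy`, i.e. $\frac{|S_v|}{|S|} H(S_v)$.

```latex
\operatorname{Gain}(S, A) = H(S) - \sum_{v \in \operatorname{Values}(A)} \frac{|S_v|}{|S|}\, H(S_v)
```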

     

     

    Results of the calculation (worked out on a paper draft; please excuse the handwriting):
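Recomputing these quantities from the table gives the values below. They are rounded to four decimals, and each gain is computed from unrounded intermediates. The set has 4 flu and 3 non-flu samples. Temperature has the largest gain, so it becomes the root. Its `normal` rows are all non-flu and its `very high` rows are all flu, so only the `high` branch (rows 2, 5, 7) needs another split:

```latex
\begin{aligned}
H(S) &= -\tfrac{4}{7}\log_2\tfrac{4}{7} - \tfrac{3}{7}\log_2\tfrac{3}{7} \approx 0.9852 \\
\operatorname{Gain}(S,\text{Headache}) &= 0.9852 - \left(\tfrac{4}{7}\cdot 0.8113 + \tfrac{3}{7}\cdot 0.9183\right) \approx 0.1281 \\
\operatorname{Gain}(S,\text{Muscle pain}) &= 0.9852 - \left(\tfrac{5}{7}\cdot 0.9710 + \tfrac{2}{7}\cdot 1\right) \approx 0.0060 \\
\operatorname{Gain}(S,\text{Temperature}) &= 0.9852 - \left(\tfrac{2}{7}\cdot 0 + \tfrac{3}{7}\cdot 0.9183 + \tfrac{2}{7}\cdot 0\right) \approx 0.5917 \\[6pt]
H(S_{\text{high}}) &= -\tfrac{2}{3}\log_2\tfrac{2}{3} - \tfrac{1}{3}\log_2\tfrac{1}{3} \approx 0.9183 \\
\operatorname{Gain}(S_{\text{high}},\text{Headache}) &= 0.9183 - \left(\tfrac{2}{3}\cdot 0 + \tfrac{1}{3}\cdot 0\right) \approx 0.9183 \\
\operatorname{Gain}(S_{\text{high}},\text{Muscle pain}) &= 0.9183 - \left(\tfrac{1}{3}\cdot 0 + \tfrac{2}{3}\cdot 1\right) \approx 0.2516
\end{aligned}
```

The resulting tree is: Temperature = normal → no flu; Temperature = very high → flu; Temperature = high → split on Headache (yes → flu, no → no flu).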

     

    --------------------------------------------------------------------------------------------------------------------------------------------

     

     

    *************************************************************************************************************

    ************************Implementation*********************************************

     

     

    // File: ID3Tree/Comparisons.java
    package ID3Tree;
    import java.util.Comparator;
    
    // Orders attribute values with String.compareTo (Unicode order)
    public class Comparisons implements Comparator<String>
    {
        @Override
        public int compare(String a, String b)
        {
            return a.compareTo(b);
        }
    }
    // File: ID3Tree/Entropy.java
    package ID3Tree;
    
    public class Entropy {
        // One entropy term: -p * log2(p) with p = x / total (0 when x == 0)
        public static double getEntropy(int x, int total)
        {
            if (x == 0)
            {
                return 0;
            }
            double x_pi = getShang(x,total);
            return -(x_pi*Logs(x_pi));
        }
    
        // Base-2 logarithm
        public static double Logs(double y)
        {
            return Math.log(y) / Math.log(2);
        }
        
        // Proportion x / total ("shang" means quotient)
        public static double getShang(int x, int total)
        {
            return (double) x / total;
        }
    }
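    // File: ID3Tree/TreeNode.java
    package ID3Tree;
    
    public class TreeNode {
        // Parent node (null for the root)
        TreeNode parent;
        // Value of the parent's attribute that leads to this node
        String parentAttribute;
        // Attribute name for an internal node, class label for a leaf
        String nodeName;
        // Values of this node's attribute, one per child
        String[] attributes;
        TreeNode[] childNodes;
    }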
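    // File: ID3Tree/UtilID3.java
    package ID3Tree;
    import java.util.*;
    
    public class UtilID3 {
        TreeNode root;
        // flag[i] == true: column i is the class column or has already been used for a split
        private boolean[] flag;
        // Training set: each element is a String[] row
        private Object[] trainArrays;
        // Index of the class (label) column
        private int nodeIndex;
        public static void main(String[] args)
        {
            // Training set columns: headache, muscle pain, temperature, has flu
            Object[] arrays = new Object[]{
                    new String[]{"yes","yes","normal","no"},
                    new String[]{"yes","yes","high","yes"},
                    new String[]{"yes","yes","very high","yes"},
                    new String[]{"no","yes","normal","no"},
                    new String[]{"no","no","high","no"},
                    new String[]{"no","yes","very high","yes"},
                    new String[]{"yes","no","high","yes"}};
            UtilID3 tree = new UtilID3();
            // Column 3 holds the class label
            tree.create(arrays, 3);
        }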
    
        // Build the tree from the training set and print it
        public void create(Object[] arrays, int index)
        {
            this.trainArrays = arrays;
            initial(arrays, index);
            createDTree(arrays);
            printDTree(root);
        }
        
        // Initialization: remember the class column and exclude it from splitting
        public void initial(Object[] dataArray, int index)
        {
            this.nodeIndex = index;
            
            this.flag = new boolean[((String[])dataArray[0]).length];
            for (int i = 0; i<this.flag.length; i++)
            {
                this.flag[i] = (i == index);
            }
        }
        
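        // Build the decision tree: split the root on the attribute with the largest gain
        public void createDTree(Object[] arrays)
        {
            Object[] ob = getMaxGain(arrays);
            int index = ((Integer)ob[1]).intValue();
            if (this.root == null)
            {
                this.root = new TreeNode();
                root.parent = null;
                root.parentAttribute = null;
                if (index == -1)
                {
                    // No attribute has positive gain (e.g. every row has the same class): the root is a leaf
                    root.attributes = new String[0];
                    root.nodeName = getLeafNodeName(arrays);
                    root.childNodes = new TreeNode[0];
                    return;
                }
                root.attributes = getAttributes(index);
                root.nodeName = getNodeName(index);
                root.childNodes = new TreeNode[root.attributes.length];
                insert(arrays, root);
            }
        }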
        
        // Recursively create the children of parentNode, one per value of its attribute
        public void insert(Object[] arrays, TreeNode parentNode)
        {
            String[] attributes = parentNode.attributes;
            // Attributes used on the path to parentNode. Every branch starts from this state,
            // so an attribute chosen inside one branch stays available to its siblings.
            boolean[] usedOnPath = this.flag.clone();
            for (int i = 0; i < attributes.length; i++)
            {
                this.flag = usedOnPath.clone();
                Object[] subset = pickUpAndCreateArray(arrays, attributes[i], getNodeIndex(parentNode.nodeName));
                Object[] info = getMaxGain(subset);
                double gain = ((Double)info[0]).doubleValue();
                // Treat rounding-level gains as zero
                if (gain > 1e-10)
                {
                    int index = ((Integer)info[1]).intValue();
                    TreeNode currentNode = new TreeNode();
                    currentNode.parent = parentNode;
                    currentNode.parentAttribute = attributes[i];
                    currentNode.attributes = getAttributes(index);
                    currentNode.nodeName = getNodeName(index);
                    currentNode.childNodes = new TreeNode[currentNode.attributes.length];
                    parentNode.childNodes[i] = currentNode;
                    insert(subset, currentNode);
                }
                else
                {
                    // Leaf. An empty branch (value absent from this subset) takes its label from the parent's rows.
                    TreeNode leafNode = new TreeNode();
                    leafNode.parent = parentNode;
                    leafNode.parentAttribute = attributes[i];
                    leafNode.attributes = new String[0];
                    leafNode.nodeName = getLeafNodeName(subset.length > 0 ? subset : arrays);
                    leafNode.childNodes = new TreeNode[0];
                    parentNode.childNodes[i] = leafNode;
                }
            }
        }
        
        // Print the tree depth-first: each node's name, with "If: <value>" before each child
        public void printDTree(TreeNode node)
        {
            System.out.println(node.nodeName);
            TreeNode[] childs = node.childNodes;
            for (int i = 0; i < childs.length; i++)
            {
                if (childs[i] != null)
                {
                    System.out.println("If: "+childs[i].parentAttribute);
                    printDTree(childs[i]);
                }
            }
        }
        
        // Select the rows whose column `index` equals `attribute`
        public Object[] pickUpAndCreateArray(Object[] arrays, String attribute, int index)
        {
            List<String[]> list = new ArrayList<String[]>();
            for (int i = 0; i < arrays.length; i++)
            {
                String[] strs = (String[])arrays[i];
                if (strs[index].equals(attribute))
                {
                    list.add(strs);
                }
            }
            return list.toArray();
        }
        
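        // Attribute (column) names, indexed by column
        private static final String[] COLUMN_NAMES = {"Headache","Muscle pain","Temperature","Flu"};
        
        // Column index -> attribute name (null if out of range)
        public String getNodeName(int index)
        {
            if (index >= 0 && index < COLUMN_NAMES.length)
            {
                return COLUMN_NAMES[index];
            }
            return null;
        }
        
        // Leaf label: the most frequent class among the rows
        // (on a tie, the label that reaches the top count first wins; null for no rows)
        public String getLeafNodeName(Object[] arrays)
        {
            if (arrays == null)
            {
                return null;
            }
            Map<String, Integer> counts = new HashMap<String, Integer>();
            String best = null;
            int bestCount = 0;
            for (int i = 0; i < arrays.length; i++)
            {
                String label = ((String[])arrays[i])[nodeIndex];
                Integer c = counts.get(label);
                int count = (c == null) ? 1 : c + 1;
                counts.put(label, count);
                if (count > bestCount)
                {
                    best = label;
                    bestCount = count;
                }
            }
            return best;
        }
        
        // Attribute name -> column index (-1 if unknown)
        public int getNodeIndex(String name)
        {
            for (int i = 0; i < COLUMN_NAMES.length; i++)
            {
                if (COLUMN_NAMES[i].equals(name))
                {
                    return i;
                }
            }
            return -1;
        }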
        
        
        
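        // Find the unused attribute with the largest information gain.
        // Returns {gain, column index} (index -1 if no gain > 0) and marks that column as used in flag.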
        public Object[] getMaxGain(Object[] arrays)
        {
            Object[] result = new Object[2];
            double gain = 0;
            int index = -1;
            for (int i = 0; i<this.flag.length; i++)
            {
                if (!this.flag[i])
                {
                    double value = gain(arrays, i);
                    if (gain < value)
                    {
                        gain = value;
                        index = i;
                    }
                }
            }
            result[0] = gain;
            result[1] = index;
            if (index != -1)
            {
                this.flag[index] = true;
            }
            return result;
        }
        
        // Distinct values of column `index` over the whole training set, sorted
        public String[] getAttributes(int index)
        {
            @SuppressWarnings("unchecked")
            TreeSet<String> set = new TreeSet<String>(new Comparisons());
            for (int i = 0; i<this.trainArrays.length; i++)
            {
                String[] strs = (String[])this.trainArrays[i];
                set.add(strs[index]);
            }
            String[] result = new String[set.size()];
            return set.toArray(result);
            
        }
        
        // Information gain of splitting `arrays` on column `index`:
        // Gain = H(S) - sum over values v of |S_v|/|S| * H(S_v)
        public double gain(Object[] arrays, int index)
        {
            // Class labels and how often each one occurs in arrays (Java zero-initializes counts)
            String[] classLabels = getAttributes(this.nodeIndex);
            int[] counts = new int[classLabels.length];
            
            for (int i = 0; i<arrays.length; i++)
            {
                String[] strs = (String[])arrays[i];
                for (int j = 0; j<classLabels.length; j++)
                {    
                    if (strs[this.nodeIndex].equals(classLabels[j]))
                    {
                        counts[j]++;
                    }
                }
            }
            
            // H(S): entropy of the class distribution before the split
            double entropyS = 0;
            for (int i = 0;i <counts.length; i++)
            {
                entropyS = entropyS + Entropy.getEntropy(counts[i], arrays.length);
            }
            
            // Weighted entropy after splitting on each value of the attribute
            String[] attributes = getAttributes(index);
            double total = 0;
            for (int i = 0; i<attributes.length; i++)
            {
                total = total + entropy(arrays, index, attributes[i], arrays.length);
            }
            return entropyS - total;
        }
        
        
        // Weighted entropy |S_v|/|S| * H(S_v), where S_v is the subset of `arrays`
        // whose column `index` equals `attribute` and `totals` is |S|
        public double entropy(Object[] arrays, int index, String attribute, int totals)
        {
            String[] classLabels = getAttributes(this.nodeIndex);
            int[] counts = new int[classLabels.length];
            
            // Count the class labels within S_v
            for (int i = 0; i < arrays.length; i++)
            {
                String[] strs = (String[])arrays[i];
                if (strs[index].equals(attribute))
                {
                    for (int k = 0; k<classLabels.length; k++)
                    {
                        if (strs[this.nodeIndex].equals(classLabels[k]))
                        {
                            counts[k]++;
                        }
                    }
                }
            }
            
            int total = 0;
            double entropy = 0;
            for (int i = 0; i < counts.length; i++)
            {
                total = total +counts[i];
            }
            // An empty subset contributes nothing (avoids 0/0 when S itself is empty)
            if (total == 0)
            {
                return 0;
            }
            
            for (int i = 0; i < counts.length; i++)
            {
                entropy = entropy + Entropy.getEntropy(counts[i], total);
            }
            return Entropy.getShang(total, totals)*entropy;
        }
    }
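For comparison, here is a compact sketch of the same ID3 procedure. It passes the set of still-unused attributes down each path explicitly, creates branches only for values that occur in the current subset, and labels each leaf with its majority class. The class `CompactID3` and its helpers (`build`, `partition`, `majority`, `classify`) are hypothetical and not part of the original post. It assumes Java 8 or later and the English value names used in the table.

```java
import java.util.*;

// Hypothetical compact ID3 sketch (not the original author's code):
// attributes are removed per path, and leaves take the majority label.
public class CompactID3 {
    // A leaf has a non-null label; an internal node splits on column attr
    static class Node {
        String label;
        int attr = -1;
        Map<String, Node> children = new TreeMap<>();
    }

    // Entropy (bits) of the class column over the rows
    static double entropy(List<String[]> rows, int labelCol) {
        Map<String, Integer> counts = new HashMap<>();
        for (String[] r : rows) counts.merge(r[labelCol], 1, Integer::sum);
        double h = 0;
        for (int c : counts.values()) {
            double p = (double) c / rows.size();
            h -= p * Math.log(p) / Math.log(2);
        }
        return h;
    }

    // Group rows by their value in column attr (only values that occur)
    static Map<String, List<String[]>> partition(List<String[]> rows, int attr) {
        Map<String, List<String[]>> parts = new TreeMap<>();
        for (String[] r : rows) parts.computeIfAbsent(r[attr], k -> new ArrayList<>()).add(r);
        return parts;
    }

    // Most frequent class label; alphabetically first among labels tied for the top count
    static String majority(List<String[]> rows, int labelCol) {
        Map<String, Integer> counts = new TreeMap<>();
        for (String[] r : rows) counts.merge(r[labelCol], 1, Integer::sum);
        String best = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            if (e.getValue() > bestCount) { best = e.getKey(); bestCount = e.getValue(); }
        }
        return best;
    }

    // Recursively build the tree; `remaining` holds the columns still usable on this path
    static Node build(List<String[]> rows, Set<Integer> remaining, int labelCol) {
        Node node = new Node();
        double h = entropy(rows, labelCol);
        int best = -1;
        double bestGain = 1e-12;                  // ignore rounding-level gains
        for (int a : remaining) {
            double rest = 0;
            for (List<String[]> part : partition(rows, a).values())
                rest += (double) part.size() / rows.size() * entropy(part, labelCol);
            if (h - rest > bestGain) { bestGain = h - rest; best = a; }
        }
        if (best == -1) {                         // pure node, or no attribute helps
            node.label = majority(rows, labelCol);
            return node;
        }
        node.attr = best;
        Set<Integer> next = new TreeSet<>(remaining);
        next.remove(best);                        // removed for this subtree only
        for (Map.Entry<String, List<String[]>> e : partition(rows, best).entrySet())
            node.children.put(e.getKey(), build(e.getValue(), next, labelCol));
        return node;
    }

    // Follow the branches for one row; null if a value was never seen on that path
    static String classify(Node node, String[] row) {
        while (node.label == null) {
            node = node.children.get(row[node.attr]);
            if (node == null) return null;
        }
        return node.label;
    }

    static void print(Node node, String[] names, String indent) {
        if (node.label != null) { System.out.println(indent + "=> " + node.label); return; }
        for (Map.Entry<String, Node> e : node.children.entrySet()) {
            System.out.println(indent + names[node.attr] + " = " + e.getKey());
            print(e.getValue(), names, indent + "  ");
        }
    }

    public static void main(String[] args) {
        String[] names = {"Headache", "Muscle pain", "Temperature", "Flu"};
        List<String[]> data = Arrays.asList(
            new String[]{"yes", "yes", "normal", "no"},
            new String[]{"yes", "yes", "high", "yes"},
            new String[]{"yes", "yes", "very high", "yes"},
            new String[]{"no", "yes", "normal", "no"},
            new String[]{"no", "no", "high", "no"},
            new String[]{"no", "yes", "very high", "yes"},
            new String[]{"yes", "no", "high", "yes"});
        Node root = build(data, new TreeSet<>(Arrays.asList(0, 1, 2)), 3);
        print(root, names, "");
    }
}
```

For the seven rows above, this sketch splits on Temperature first and then on Headache under `high`. That matches the hand calculation, and it classifies every training row correctly.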

  • Original article: https://www.cnblogs.com/tk55/p/6231206.html
Copyright © 2011-2022 走看看