zoukankan      html  css  js  c++  java
  • ID3算法 决策树 C++实现

    人工智能课的实验。

    数据结构:多叉树

    这个实验我写了好久,开始的时候从数据的读入和表示入手,写到递归建树的部分时遇到了瓶颈,更新样例集和属性集的办法过于繁琐;

    于是参考网上的代码后重新写,建立决策树类,把属性集、样例集作为数据成员加入类中,并设立访问数组,这样每次更新属性集、样例集时只是标记访问数组的对应元素即可,不必实际拷贝。

    主函数:

     1 #include "Decision_tree.h"
     2 using namespace std;
     3 int main()
     4 {
     5     int num_attr,num_example;
     6     char filename[30];
     7     cout << "请输入训练集文件名:" << endl;
     8     cin >> filename;
     9     freopen(filename, "r", stdin);//从样例文件读入训练内容
    10     cin >> num_attr >> num_example;//读入属性个数、例子个数
    11     Decision_tree my_tree=Decision_tree(num_attr,num_example);
    12     fclose(stdin);
    13     freopen("CON", "r", stdin);//重定向标准输入到控制台
    14     my_tree.display_attr();
    15     cout << "决策树已建成,按深度优先遍历结果如下:" << endl;
    16     my_tree.traverse();
    17     do{
    18         cout << "请输入测试数据,格式:属性1值 属性2值..." << endl;
    19         Example test;
    20         for (int i = 0; i < num_attr; i++)
    21             cin >> test.values[i];
    22         int result = my_tree.judge(test);
    23         if (result == 1) cout << "分类结果为P" << endl;
    24         else if (result == -1) cout << "分类结果为N" << endl;
    25         else if (result == -2) cout << "无法根据已有样例集判断" << endl;
    26         cout << "继续吗?(y/n)";
    27         fflush(stdin);
    28     } while (getchar() == 'y');
    29 }

    属性结构体

    struct Attribute//属性
    {
        string name;
        int count;//属性值个数
        int number;//属性的秩
        string values[MAX_VAL];
    };

    样例结构体

    struct Example//样例
    {
        string values[MAX];
        int pn;
        Example(){ pn = 0; }//默认为未分类的
    };

    决策树的结点

    typedef struct Node//树的结点
    {
        Attribute attr;
        Node* children[MAX_VAL];
        int classification[MAX_VAL];
        Node(){}
    }Node;

    决策树类的实现

      1 class Decision_tree//决策树
      2 {
      3     Node *root;
      4     Example e[MAX];//样例全集
      5     Attribute a[MAX_ATTR];//属性全集
      6     int num_attr, num_example;
      7     int visited_exams[MAX];//样例集的访问情况
      8     int visited_attrs[MAX_ATTR];//属性集的访问情况
      9     Node* recursive_build_tree(int left_e[], int left_a[])//递归建树
     10     {
     11         double max = 0;
     12         int max_attr=-1;
     13         for (int i = 0;i<num_attr;i++)
     14         {//求信息增益最大的属性
     15             if (left_a[i]) continue;
     16             double temp = Gain(left_e, i);
     17             if (max<temp)
     18             {
     19                 max = temp;
     20                 max_attr = i;
     21             }
     22         }
     23         if (max_attr == -1) return NULL;//已没有可判的属性,返回空指针
     24         //cout << a[max_attr].name << endl;
     25         //以这个属性为结点,以各属性值为分支递归建树
     26         int p = 0, n = 0;
     27         Node *new_node=new Node();
     28         new_node->attr = a[max_attr];
     29         for (int i = 0; i<a[max_attr].count;i++)
     30         {//遍历这个属性的所有属性值
     31             for (int j = 0; j < num_example;j++)
     32             {//得到第i个属性值的正反例总数
     33                 if (left_e[j]) continue;
     34                 if (!e[j].values[max_attr].compare(a[max_attr].values[i]))
     35                 {//例子和属性都是循秩访问的,所以向量元素的顺序不能变
     36                     if (e[j].pn) p++;
     37                     else n++;
     38                 }
     39             }
     40             //cout << a[max_attr].values[i] << " ";
     41             //cout << p << " " << n << endl;
     42             if (p && !n)//全是正例,不再分
     43             {
     44                 //cout << "P" << endl;
     45                 new_node->classification[i] = 1;
     46                 new_node->children[i] = NULL;
     47             }
     48             else if (n && !p)//全是反例,不再分
     49             {
     50                 //cout << "N" << endl;
     51                 new_node->classification[i] = -1;
     52                 new_node->children[i] = NULL;
     53             }
     54             else if (!p && !n)//例子集已空
     55             {
     56                 //cout << "none" << endl;
     57                 new_node->classification[i] = -2;//表示未训练到这种分类,无法判断
     58                 new_node->children[i] = NULL;
     59             }
     60             else//例子集不空,且尚未能区分正反,更新访问情况,递归
     61             {
     62                 new_node->classification[i] = 0;
     63                 left_a[max_attr] = 1;//更新属性访问情况
     64                 int left_e_next[MAX];//下一轮的例子集(为便于回溯,不修改原例子集)
     65                 for (int k = 0; k < num_example; k++)
     66                     left_e_next[k] = left_e[k];
     67                 for (int j = 0; j < num_example; j++)
     68                 {
     69                     if (left_e[j]) continue;
     70                     if (!e[j].values[max_attr].compare(a[max_attr].values[i]))
     71                         left_e_next[j] = 0;//属性值匹配的例子,入选下一轮例子集
     72                     else left_e_next[j] = 1;//属性值不匹配,筛除
     73                 }
     74                 new_node->children[i] = recursive_build_tree(left_e_next, left_a);//递归
     75                 left_a[max_attr] = 0;//恢复属性访问情况
     76             }
     77             p = 0;
     78             n = 0;
     79         }
     80         return new_node;
     81     }
     82     double I(int p, int n)
     83     {
     84         double a = p / (p + (double)n);
     85         double b = n / (p + (double)n);
     86         if (a == 0 || b == 0) return 0;
     87         return -a*log(a) / log(2) - b*log(b) / log(2);
     88     }
     89     double Gain(int left_e[], int cur_attr)//计算信息增益
     90     {
     91         int sum_p=0, sum_n=0;
     92         int p[10] = { 0 }, n[10] = { 0 };
     93         for (int i = 0; i < num_example; i++)
     94         {//求样例集的p,n
     95             if (left_e[i]) continue;
     96             if (e[i].pn) sum_p++;
     97             else sum_n++;
     98         }
     99         if (!sum_p && !sum_n)
    100         {
    101             //cout << "no more examples!" << endl;
    102             return -1;//样例集是空集
    103         }
    104             
    105         double sum_Ipn = I(sum_p, sum_n);
    106         for (int i = 0; i < a[cur_attr].count; i++)
    107         {//求第i个属性值的p,n
    108             for (int j = 0; j < num_example; j++)
    109             {
    110                 if (left_e[j]) continue;
    111                 if (!e[j].values[cur_attr].compare(a[cur_attr].values[i]))
    112                     if (e[j].pn) p[i]++;
    113                     else n[i]++;
    114             }
    115         }
    116         double E = 0;
    117         for (int i = 0; i < a[cur_attr].count; i++)//计算属性的期望
    118             E += (p[i] + n[i])*I(p[i], n[i]);
    119         E /= (sum_p + sum_n);
    120         //cout << a[cur_attr].name <<sum_Ipn - E << endl;
    121         return sum_Ipn - E;
    122     }
    123     void recursive_traverse(Node *current)//DFS递归遍历
    124     {
    125         if (current == NULL) return;
    126         cout << current->attr.name << endl;
    127         for (int i = 0; i < current->attr.count; i++)
    128         {
    129             cout << current->attr.values[i] << " " << current->classification[i] << endl;
    130             recursive_traverse(current->children[i]);
    131         }
    132     }
    133     int recursive_judge(Example exa, Node *current)
    134     {
    135         for (int i = 0; i < current->attr.count; i++)
    136         {
    137             if (!exa.values[current->attr.number].compare(current->attr.values[i]))
    138             {
    139                 if (current->children[i]==NULL) return current->classification[i];
    140                 else return recursive_judge(exa, current->children[i]);        
    141             }        
    142         }
    143         return 0;
    144     }
    145 public:
    146     Decision_tree(int num1,int num2)
    147     {
    148         
    149         //通过读文件初始化
    150         num_attr = num1;
    151         num_example = num2;
    152 
    153         for (int i = 0; i<num_attr; i++)
    154         {
    155             a[i].number = i;//属性的秩
    156             cin>>a[i].name;//读入属性名
    157             cin>>a[i].count;//读入此属性的属性值个数
    158             for (int j = 0; j<a[i].count; j++)
    159             {
    160                 cin>>a[i].values[j];//读入各属性值
    161             }
    162         }
    163         
    164         for (int i = 0; i<num_example; i++)
    165         {
    166             string temp;
    167             for (int j = 0; j < num_attr; j++)
    168             {
    169                 cin>>e[i].values[j];
    170             }
    171             cin >> temp;
    172             if (!temp.compare("P")) e[i].pn = 1;
    173             else e[i].pn = 0;
    174         }
    175         //检查
    176         /*for (int i = 0; i<num_attr; i++)
    177         {
    178             cout << a[i].name << endl;//读入属性名
    179             for (int j = 0; j<a[i].count; j++)
    180             {
    181                 cout<<a[i].values[j]<<" ";//读入各属性值
    182             }
    183             cout << endl;
    184         }
    185         for (int i = 0; i<num_example; i++)
    186         {
    187             for (int j = 0; j < num_attr; j++)
    188                 cout<<e[i].values[j]<<" ";
    189             cout<<e[i].pn<<endl;
    190             
    191         }
    192         */
    193         memset(visited_exams, 0, sizeof(visited_exams));
    194         memset(visited_attrs, 0, sizeof(visited_attrs));
    195         root = recursive_build_tree(visited_exams,visited_attrs);
    196     }
    197     void traverse()
    198     {
    199         recursive_traverse(root);
    200     }
    201     int judge(Example exa)//判断
    202     {
    203         int result=recursive_judge(exa,root);
    204         return result;
    205     }
    206     void display_attr()//显示属性
    207     {
    208         cout << "There are " << num_attr << " attributes, they are" << endl;
    209         for (int i = 0; i < num_attr; i++)
    210         {
    211             cout << "[" << a[i].name << "]" << endl;
    212             for (int j = 0; j < a[i].count; j++)
    213                 cout << a[i].values[j] << " ";
    214             cout << endl;
    215         }
    216     }
    217 };
    Decision_tree

    现在这个版本的代码用了10小时完成,去检查时被研究生贬得一文不值。。。也的确,现在我们写的实验题目面向的都是规模非常小的问题,自然体会不到自己的代码在大数据面前的劣势。不过我现在确实学得太少了,很多数据结构都没有动手实现过,算法也是。对C++也只能算入了门。俗话说“磨刀不误砍柴工”,“工欲善其事,必先利其器”,先把基础知识学好,多做基本练习,学到的数据结构和算法都动手实现一遍,这样遇到实际问题也好对应到合适的数据结构和算法。

    另外,参照一本书学习好的代码风格和习惯也是很重要的,因为写代码的习惯是思维习惯的反映,而我现在还处于初学者阶段,按照一种典型的流派模仿,构建起自己的思维模式后再谈其他的。

    忽然觉得自己学了快两年编程还这么水实在是不能忍,都怪大一时年少不懂事没好好学基础。。。

    不过,“悟已往之不谏,知来者之可追”,有了方向,一步步走下去就好,不求优于别人,但一定要“优于过去的自己”。

  • 相关阅读:
    System.Data.RealonlyException:列Column1被设置为realonly
    学习java过程中
    在windows server 2008下安装vs2005.打开vs2005的时候老提示要“运行vs2005sp1 建议使用管理员权限”
    windows Server 2008下面运行vs2005的问题
    大飞机MIS系统360把我的Transformer.Service服务杀掉了
    开通博客
    C#中怎样让控件显示在其他控件的上面
    vs2010发布问题
    vs在IE8无法调试的解决方法
    将身份证号粘贴到WPS表格后变成了“科学计数法”的解决方案
  • 原文地址:https://www.cnblogs.com/helenawang/p/4582081.html
Copyright © 2011-2022 走看看