zoukankan      html  css  js  c++  java
  • 决策树ID3算法示例

    决策树代码如下:

    #include "MyID3.h"
    using namespace std;
    void ReadData()        //读入数据
    {
        ifstream fin("F:\data.txt");
        for(int i=0;i<NUM;i++)
        {
          for(int j=0;j<6;j++)
            {
                fin>>DataTable[i][j];
                cout<<DataTable[i][j]<<"	";
            }
          cout<<endl;
        }
        fin.close();
    }
    
    // Base-2 logarithm with the two special inputs 0 and 1 pinned to 0.
    //
    // p: value whose logarithm is wanted (read only, never modified).
    // Returns 0 when p is exactly 0 or exactly 1, otherwise log(p)/log(2).
    double ComputLog(double &p)
    {
        const bool pinned=(p==0||p==1);
        return pinned?0:log(p)/log(2);
    }
    
    // Entropy, in bits, of a two-outcome distribution.
    //
    // p: probability of one outcome; the other outcome has probability 1-p.
    // Returns p*log2(1/p) + (1-p)*log2(1/(1-p)).
    // For p exactly 0 or exactly 1 the result is 0: one outcome is certain,
    // so there is no uncertainty to measure.
    double ComputInfo(double &p)
    {
        // Without this guard 1/p (or 1/q) is infinite and the product
        // 0*infinity below turns the whole result into NaN.
        if(p==0||p==1)
            return 0;

        double q=1-p;
        double m=1/p;
        double n=1/q;
        return (p*ComputLog(m)+q*ComputLog(n));
    }
    
    // Tallies the class labels of rows begin..end (both inclusive) of DataTable.
    //
    // begin, end: first and last row index to inspect.
    // CountP:     receives the number of rows whose column 5 equals "Yes".
    // CountN:     receives the number of all remaining rows in the range.
    // Both counters are reset to 0 before counting starts.
    void CountInfoNP(int begin,int end,int &CountP,int &CountN)
    {
        CountP=CountN=0;
        for(int row=begin;row<=end;++row)
        {
            // Pick the counter that matches this row's label and bump it.
            int &tally=(DataTable[row][5]=="Yes")?CountP:CountN;
            ++tally;
        }
    }
    
    // Looks for an attribute value among the first `count` entries of
    // DataValueWeight and, when it is already there, updates that entry.
    //
    // data:   attribute value to look up.
    // count:  number of entries of DataValueWeight currently in use.
    // result: class label of the row being processed.
    // Returns false when the value was found; in that case the entry's
    //         ValueWeight is increased by one, and so is ValuePWeight (label
    //         "Yes") or ValueNWeight (any other label).
    // Returns true when the value is not recorded yet; nothing is modified.
    bool CompareData(string &data,int &count,string &result)
    {
        // Advance to the first entry carrying this value, if any.
        int slot=0;
        while(slot<count&&!(data==DataValueWeight[slot].AttriValueName))
            ++slot;

        if(slot>=count)
            return true;                      // value has not been seen before

        DataValueWeight[slot].ValueWeight+=1;
        if(result=="Yes")
            DataValueWeight[slot].ValuePWeight+=1;
        else
            DataValueWeight[slot].ValueNWeight+=1;
        return false;
    }
    
    int SearchData(const int &begin,const int &end,const int &k)        //对于第k列进行检索
    {
        //cout<<"Enter SearchData()  "<<begin<<"  "<<end<<"  "<<k<<endl;
        int count=0;
        for(int i=0;i<VALUENUM;i++)
            {
                DataValueWeight[i].ValueWeight=0;
                DataValueWeight[i].ValueNWeight=0;
                DataValueWeight[i].ValuePWeight=0;
            }
    
        for(int i=begin;i<=end;i++)
            if(i==begin)
               {
                 DataValueWeight[count].AttriValueName=DataTable[i][k];
                 DataValueWeight[count].ValueWeight+=1;
                 if(DataTable[i][5]=="Yes")
                    DataValueWeight[count].ValuePWeight+=1;
                 else
                    DataValueWeight[count].ValueNWeight+=1;
    
                 count++;
               }
            else
            {
                string data=DataTable[i][k];
                string result=DataTable[i][5];
                if(CompareData(data,count,result))                             //如果该值没有出现过
                {
                    DataValueWeight[count].AttriValueName=data;
                    DataValueWeight[count].ValueWeight+=1;
    
    
    
                    if(DataTable[i][5]=="Yes")
                        DataValueWeight[count].ValuePWeight+=1;
                    else
                        DataValueWeight[count].ValueNWeight+=1;
                    count++;
                }
            }
    
    
         //for(int s=0;s<count;s++)
         //   cout<<"Hello: "<<DataValueWeight[s].AttriValueName<<"	"<<DataValueWeight[s].ValueWeight<<
         //   "	"<<DataValueWeight[s].ValuePWeight<<" 	"<<DataValueWeight[s].ValueNWeight<<endl;
    
    
        for(int i=0;i<count;i++)
        {
            if(DataValueWeight[i].ValueNWeight!=0)
                DataValueWeight[i].ValueNWeight=DataValueWeight[i].ValueWeight/DataValueWeight[i].ValueNWeight;
            else
                DataValueWeight[i].ValueNWeight=0;
    
            if(DataValueWeight[i].ValuePWeight!=0)
                DataValueWeight[i].ValuePWeight=DataValueWeight[i].ValueWeight/DataValueWeight[i].ValuePWeight;
            else
                DataValueWeight[i].ValuePWeight=0;
            //cout<<"N: "<<DataValueWeight[i].ValueNWeight<<"  P: "<<DataValueWeight[i].ValuePWeight<<endl;
        }
        return count;
    }
    
    // Chooses the attribute with the largest AttriI among InfoResult[1..4].
    //
    // Returns the index (1..4) of the first attribute holding the maximum.
    // Attribute 1 is the starting candidate, so a valid index is returned even
    // when no attribute has a value above zero; previously the result was an
    // uninitialised variable in that situation.
    int PickAttri()
    {
        int pos=1;
        double max=InfoResult[1].AttriI;

        for(int i=2;i<5;i++)
            if(InfoResult[i].AttriI>max)
            {
                pos=i;
                max=InfoResult[i].AttriI;
            }
        return pos;
    }
    // Regroups rows begin..end (inclusive) of DataTable so that rows sharing
    // the same value in column `temp` become adjacent.
    //
    // The groups follow the order of the first InfoResult[temp].AttriKind
    // entries of DataValueWeight; inside a group the rows keep their previous
    // relative order. CopyDataTable is used as scratch space.
    //
    // begin, end: first and last row index of the range to regroup.
    // temp:       column whose values define the groups.
    // position:   output; position[0] is begin and position[g+1] is the row
    //             index just past group g. The caller must supply room for
    //             AttriKind+1 entries.
    // Returns the number of entries written to position.
    int  SortByAttriValue(int &begin,int &end,int &temp,int *position)
    {
        const int rowCount=end-begin+1;

        // Snapshot the range; DataTable is overwritten while regrouping.
        for(int r=0;r<rowCount;++r)
            for(int c=0;c<=5;++c)
                CopyDataTable[r][c]=DataTable[begin+r][c];

        int writeRow=begin;          // next row of DataTable to be filled
        int filled=0;                // entries of position[] written so far
        position[filled++]=begin;

        const int groups=InfoResult[temp].AttriKind;
        for(int g=0;g<groups;++g)
        {
            for(int r=0;r<rowCount;++r)
            {
                if(CopyDataTable[r][temp]==DataValueWeight[g].AttriValueName)
                {
                    for(int c=0;c<6;++c)
                        DataTable[writeRow][c]=CopyDataTable[r][c];
                    ++writeRow;
                }
            }
            position[filled++]=writeRow;     // where the next group begins
        }
        return filled;
    }
    
    void BuildTree(int begin,int end,Node *parent)
    {
        int CountP=0,CountN=0;
        CountInfoNP(begin,end,CountP,CountN);
    
        cout<<"************************ The data be sorted **************************"<<endl;
        for(int i=begin;i<=end;i++)
        {
            for(int j=0;j<=5;j++)
                cout<<DataTable[i][j]<<"	";
            cout<<endl;
        }
        cout<<"
    
    
    ";
    
        cout<<parent->AttriName<<" have a look: "<<CountP<<endl;
        if(CountP==0||CountN==0)               //该子集当中只包含Yes或者No时为叶子节点,返回调用处;
        {
            cout<<"creat leaf node"<<endl;
            Node* t=new Node();                                    //建立叶子节点
            if(CountP==0)
                t->AttriName="No";
            else
                t->AttriName="Yes";
            parent->Children.push_back(t);                             //插入孩子节点
            return;
        }
        else
        {
            double p=(double)CountP/(CountP+CountN);
            double InfoH=ComputInfo(p);                            //获得信息熵
    
            for(int k=1;k<5;k++)                                   //循环计算各个属性的条件信息熵,并计算出互信息
            {
                int KindOfValue=SearchData(begin,end,k);
                int sum=1+end-begin;
                for(int j=0;j<KindOfValue;j++)                     //计算出属性的每种取值的权重的倒数
                    DataValueWeight[j].ValueWeight=DataValueWeight[j].ValueWeight/sum;
    
                double InfoGain=0;
                if(DataValueWeight[0].ValueNWeight!=0&&DataValueWeight[0].ValuePWeight!=0)
                    InfoGain=DataValueWeight[0].ValueWeight*(ComputLog(DataValueWeight[0].ValueNWeight)/DataValueWeight[0].ValueNWeight+ComputLog(DataValueWeight[0].ValuePWeight)/DataValueWeight[0].ValuePWeight);
    
                for(int j=1;j<KindOfValue;j++)                     //计算条件信息
                if(DataValueWeight[j].ValueNWeight!=0&&DataValueWeight[j].ValuePWeight!=0)
                    InfoGain+=DataValueWeight[j].ValueWeight*(ComputLog(DataValueWeight[j].ValueNWeight)/DataValueWeight[j].ValueNWeight+ComputLog(DataValueWeight[j].ValuePWeight)/DataValueWeight[j].ValuePWeight);
    
                InfoResult[k].AttriI=InfoH-InfoGain;               //计算互信息
                InfoResult[k].AttriKind=KindOfValue;
            }
            int temp=PickAttri();                                            //选出互信息最大的属性作为节点建树
            Node* t=new Node();
            t->AttriName=InfoResult[temp].AttriName;
            SearchData(begin,end,temp);
            for(int k=0;k<InfoResult[temp].AttriKind;k++)
            {
                string name=DataValueWeight[k].AttriValueName;
                t->AttriValue.push_back(name);
            }
            t->parent=parent;
            parent->Children.push_back(t);                                   //孩子节点压入vector当中
            int position[NUMOFPOS];
    
            cout<<"before SortByAttriValue Begin: "<<begin<<",END: "<<end<<endl;
    
    
            SortByAttriValue(begin,end,temp,position);                                     //将数据按照选定属性的取值不同进行划分
            int times=InfoResult[temp].AttriKind;
            for(int l=0;l<=times;l++)
                cout<<position[l]<<" ";
            cout<<endl;
            for(int k=0;k<times;k++)
                {
                    int head,rear;
                    head=position[k];
                    int hire=k+1;
                    rear=position[hire]-1;
                    for(int l=0;l<=times;l++)
                    cout<<position[l]<<" ";
                    cout<<endl;
                    cout<<"Head: "<<head<<" ,Rear: "<<rear<<endl;
                    BuildTree(head,rear,t);
                }
        }
    }
    
    // Prints the subtree rooted at `root` to standard output, parents first.
    //
    // Each node prints its AttriName on a line of its own. A node named "Yes"
    // or "No" is treated as a leaf and nothing more is printed for it. Any
    // other node then prints its AttriValue entries on one line, each followed
    // by a space, and finally its children in stored order.
    // root: node to print; must not be null.
    void ShowTree(Node *root)
    {
        const bool isLeaf=(root->AttriName=="Yes"||root->AttriName=="No");
        cout<<root->AttriName<<endl;
        if(isLeaf)
            return;

        for(vector<string>::size_type v=0;v<root->AttriValue.size();++v)
            cout<<root->AttriValue[v]<<" ";
        cout<<endl;

        for(vector<Node*>::size_type c=0;c<root->Children.size();++c)
            ShowTree(root->Children[c]);
    }
    int main()
    {
        InfoResult[1].AttriName="天气";
        InfoResult[2].AttriName="气温";
        InfoResult[3].AttriName="湿度";
        InfoResult[4].AttriName="";
        ReadData();
        Node *Root=new Node;
        BuildTree(0,NUM-1,Root);
    
        //vector<Node>::iterator it=Root.Children.begin();
        ShowTree(Root);
        /*Node t=*it;
        cout<<t.AttriName<<endl;
        for(vector<string>::iterator itvalue=t.AttriValue.begin();itvalue!=t.AttriValue.end();itvalue++)
        cout<<*itvalue<<endl;
        it=t.Children.begin();
    
        t=*it;
        cout<<t.AttriName<<endl;*/
        //ShowTree(t);
        //cout<<"Root: "<<t.AttriName<<" ,Value: "<<*(t.AttriValue.begin())<<endl;
        return 0;
    }
    态度决定高度,细节决定成败。
  • 相关阅读:
    JAVA日报
    JAVA日报
    JAVA日报
    JAVA
    leetcode刷题笔记 222题 完全二叉树的节点个数
    leetcode刷题笔记 221题 最大正方形
    leetcode刷题笔记 220题 存在重复元素 III
    leetcode刷题笔记 219题 存在重复元素 II
    leetcode刷题笔记 218题 天际线问题
    leetcode刷题笔记 216题 组合总和 III
  • 原文地址:https://www.cnblogs.com/lxk2010012997/p/3693787.html
Copyright © 2011-2022 走看看