决策树（ID3 算法）代码如下：
#include "MyID3.h"

#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

using namespace std;

// ID3 decision-tree builder.
//
// Relies on globals declared in MyID3.h (not visible here):
//   DataTable      - NUM training records of 6 string fields; column 5 is the
//                    class label ("Yes" is positive, anything else negative),
//                    columns 1..4 are the attributes named in main(); column 0
//                    is never used as an attribute.
//   CopyDataTable  - scratch copy used while regrouping a range.
//   DataValueWeight- per-value statistics of the column last scanned by
//                    SearchData().
//   InfoResult     - per-attribute (index 1..4) name, information gain
//                    (AttriI) and number of distinct values (AttriKind).

// Reads NUM records of 6 whitespace-separated fields from F:\data.txt into
// DataTable, echoing each record to stdout. Terminates the program if the
// file cannot be opened or contains fewer than NUM complete records.
void ReadData()
{
    // The backslash must be escaped; "F:\data.txt" would contain the invalid
    // escape sequence "\d" and not name the intended file.
    const char *path = "F:\\data.txt";
    ifstream fin(path);
    if (!fin)
    {
        cerr << "Cannot open data file " << path << endl;
        exit(EXIT_FAILURE);
    }
    for (int i = 0; i < NUM; i++)
    {
        for (int j = 0; j < 6; j++)
        {
            if (!(fin >> DataTable[i][j]))
            {
                cerr << "Data file ended early at record " << i << ", field " << j << endl;
                exit(EXIT_FAILURE);
            }
            cout << DataTable[i][j] << " ";
        }
        cout << endl;
    }
}

// Returns log2(p). Returns 0 for p == 0 (where log is undefined) and for
// p == 1.
double ComputLog(double &p)
{
    if (p == 0 || p == 1)
        return 0;
    return log(p) / log(2.0);
}

// Binary entropy in bits: H(p) = p*log2(1/p) + (1-p)*log2(1/(1-p)).
// A pure set (p <= 0 or p >= 1) has zero entropy; returning early also avoids
// evaluating 0 * log2(inf) = NaN.
double ComputInfo(double &p)
{
    if (p <= 0 || p >= 1)
        return 0;
    double q = 1 - p;
    double m = 1 / p;
    double n = 1 / q;
    return p * ComputLog(m) + q * ComputLog(n);
}

// Counts the records in DataTable[begin..end] (inclusive) labelled "Yes"
// (CountP) and labelled anything else (CountN).
void CountInfoNP(int begin, int end, int &CountP, int &CountN)
{
    CountP = 0;
    CountN = 0;
    for (int i = begin; i <= end; i++)
    {
        if (DataTable[i][5] == "Yes")
            CountP++;
        else
            CountN++;
    }
}

// Looks for `data` among the first `count` entries of DataValueWeight.
// If found, increments that entry's total count and its "Yes"/other count
// (chosen by `result`) and returns false. Returns true if `data` has not
// been seen yet.
bool CompareData(string &data, int &count, string &result)
{
    for (int k = 0; k < count; k++)
    {
        if (data == DataValueWeight[k].AttriValueName)
        {
            DataValueWeight[k].ValueWeight += 1;
            if (result == "Yes")
                DataValueWeight[k].ValuePWeight += 1;
            else
                DataValueWeight[k].ValueNWeight += 1;
            return false;
        }
    }
    return true;
}

// Collects the distinct values of column k over DataTable[begin..end] into
// DataValueWeight[0..count-1] and returns count. For each value:
//   ValueWeight  - number of records having that value;
//   ValuePWeight - ValueWeight / (#"Yes" records with that value), or 0;
//   ValueNWeight - ValueWeight / (#other records with that value), or 0.
// That is, the P/N fields end up holding the reciprocals of the class
// proportions within that value, as BuildTree's entropy formula expects.
int SearchData(const int &begin, const int &end, const int &k)
{
    for (int i = 0; i < VALUENUM; i++)
    {
        DataValueWeight[i].ValueWeight = 0;
        DataValueWeight[i].ValueNWeight = 0;
        DataValueWeight[i].ValuePWeight = 0;
    }

    int count = 0;
    for (int i = begin; i <= end; i++)
    {
        string data = DataTable[i][k];
        string result = DataTable[i][5];
        // CompareData updates an already-seen value in place; only unseen
        // values need a new entry.
        if (CompareData(data, count, result))
        {
            DataValueWeight[count].AttriValueName = data;
            DataValueWeight[count].ValueWeight += 1;
            if (result == "Yes")
                DataValueWeight[count].ValuePWeight += 1;
            else
                DataValueWeight[count].ValueNWeight += 1;
            count++;
        }
    }

    for (int i = 0; i < count; i++)
    {
        if (DataValueWeight[i].ValueNWeight != 0)
            DataValueWeight[i].ValueNWeight = DataValueWeight[i].ValueWeight / DataValueWeight[i].ValueNWeight;
        if (DataValueWeight[i].ValuePWeight != 0)
            DataValueWeight[i].ValuePWeight = DataValueWeight[i].ValueWeight / DataValueWeight[i].ValuePWeight;
    }
    return count;
}

// Returns the attribute index (1..4) with the largest information gain
// (InfoResult[i].AttriI), considering only attributes that take at least two
// distinct values in the current subset (InfoResult[i].AttriKind >= 2).
// Ties go to the lower index. Returns -1 if no attribute can split the subset.
int PickAttri()
{
    int pos = -1;
    for (int i = 1; i < 5; i++)
    {
        // A single-valued attribute does not partition the data; choosing it
        // would make BuildTree recurse on the same range forever.
        if (InfoResult[i].AttriKind < 2)
            continue;
        if (pos == -1 || InfoResult[i].AttriI > InfoResult[pos].AttriI)
            pos = i;
    }
    return pos;
}

// Stably regroups DataTable[begin..end] so that records sharing a value of
// column `temp` are contiguous, in the order of
// DataValueWeight[0..AttriKind-1] (which SearchData must have filled for
// column `temp`). On return, position[0] = begin and position[i+1] is one past
// the last record of group i, so group i occupies
// [position[i], position[i+1] - 1]. Returns the number of entries written to
// `position` (AttriKind + 1).
int SortByAttriValue(int &begin, int &end, int &temp, int *position)
{
    // Work from a copy so records can be written back in grouped order.
    for (int i = begin; i <= end; i++)
        for (int j = 0; j <= 5; j++)
            CopyDataTable[i - begin][j] = DataTable[i][j];

    int high = end - begin;
    int count = 0;
    int countpos = 1;
    position[0] = begin;
    for (int i = 0; i < InfoResult[temp].AttriKind; i++)
    {
        for (int j = 0; j <= high; j++)
        {
            if (CopyDataTable[j][temp] == DataValueWeight[i].AttriValueName)
            {
                int pos = count + begin;
                for (int k = 0; k < 6; k++)
                    DataTable[pos][k] = CopyDataTable[j][k];
                count++;
            }
        }
        position[countpos] = count + begin;
        countpos++;
    }
    return countpos;
}

// Appends a new leaf labelled `label` ("Yes"/"No") under `parent`.
static void AddLeaf(Node *parent, const string &label)
{
    Node *leaf = new Node();
    leaf->AttriName = label;
    leaf->parent = parent;
    parent->Children.push_back(leaf);
}

// Recursively builds the ID3 subtree for DataTable[begin..end] (inclusive)
// and attaches it to `parent`:
//  - if every record is labelled "Yes", or none is, a leaf is created;
//  - otherwise the attribute with the highest information gain becomes an
//    internal node, the range is regrouped by that attribute's values and
//    each group is built recursively under the new node;
//  - if no attribute takes two or more values in the subset (identical
//    attributes, conflicting labels), a majority-vote leaf is created
//    ("Yes" on a tie).
// Prints diagnostic traces to stdout while working.
void BuildTree(int begin, int end, Node *parent)
{
    int CountP = 0, CountN = 0;
    CountInfoNP(begin, end, CountP, CountN);

    cout << "************************ The data be sorted **************************" << endl;
    for (int i = begin; i <= end; i++)
    {
        for (int j = 0; j <= 5; j++)
            cout << DataTable[i][j] << " ";
        cout << endl;
    }
    cout << " ";
    cout << parent->AttriName << " have a look: " << CountP << endl;

    // A subset containing only "Yes" or only "No" is a leaf.
    if (CountP == 0 || CountN == 0)
    {
        cout << "creat leaf node" << endl;
        AddLeaf(parent, CountP == 0 ? "No" : "Yes");
        return;
    }

    // Information gain of each attribute = H(subset) - sum_v w_v * H(v),
    // where w_v is the share of records with value v.
    double p = static_cast<double>(CountP) / (CountP + CountN);
    double InfoH = ComputInfo(p);
    int sum = end - begin + 1;
    for (int k = 1; k < 5; k++)
    {
        int KindOfValue = SearchData(begin, end, k);
        double CondInfo = 0;
        for (int j = 0; j < KindOfValue; j++)
        {
            DataValueWeight[j].ValueWeight = DataValueWeight[j].ValueWeight / sum;
            // The N/P fields hold reciprocals 1/q, so log2(1/q) / (1/q) = q*log2(1/q).
            // A value with a pure class split contributes zero entropy.
            if (DataValueWeight[j].ValueNWeight != 0 && DataValueWeight[j].ValuePWeight != 0)
                CondInfo += DataValueWeight[j].ValueWeight *
                            (ComputLog(DataValueWeight[j].ValueNWeight) / DataValueWeight[j].ValueNWeight +
                             ComputLog(DataValueWeight[j].ValuePWeight) / DataValueWeight[j].ValuePWeight);
        }
        InfoResult[k].AttriI = InfoH - CondInfo;
        InfoResult[k].AttriKind = KindOfValue;
    }

    int temp = PickAttri();
    if (temp < 0)
    {
        // No attribute can separate these records: fall back to the majority label.
        cout << "creat leaf node" << endl;
        AddLeaf(parent, CountP >= CountN ? "Yes" : "No");
        return;
    }

    Node *t = new Node();
    t->AttriName = InfoResult[temp].AttriName;
    // Refill DataValueWeight for the chosen column; SortByAttriValue and the
    // value list below depend on it.
    SearchData(begin, end, temp);
    for (int k = 0; k < InfoResult[temp].AttriKind; k++)
        t->AttriValue.push_back(DataValueWeight[k].AttriValueName);
    t->parent = parent;
    parent->Children.push_back(t);

    int position[NUMOFPOS];
    cout << "before SortByAttriValue Begin: " << begin << ",END: " << end << endl;
    SortByAttriValue(begin, end, temp, position);

    // Captured locally: the recursive calls below overwrite InfoResult.
    int times = InfoResult[temp].AttriKind;
    for (int l = 0; l <= times; l++)
        cout << position[l] << " ";
    cout << endl;

    for (int k = 0; k < times; k++)
    {
        int head = position[k];
        int rear = position[k + 1] - 1;
        for (int l = 0; l <= times; l++)
            cout << position[l] << " ";
        cout << endl;
        cout << "Head: " << head << " ,Rear: " << rear << endl;
        BuildTree(head, rear, t);
    }
}

// Prints the tree depth-first. A leaf ("Yes"/"No") prints its label. An
// internal node prints its attribute name, then its value list on one line,
// then each child subtree in order.
void ShowTree(Node *root)
{
    cout << root->AttriName << endl;
    if (root->AttriName == "Yes" || root->AttriName == "No")
        return;

    for (vector<string>::iterator itvalue = root->AttriValue.begin(); itvalue != root->AttriValue.end(); ++itvalue)
        cout << *itvalue << " ";
    cout << endl;
    for (vector<Node *>::iterator itnode = root->Children.begin(); itnode != root->Children.end(); ++itnode)
        ShowTree(*itnode);
}

int main()
{
    // Attribute names for columns 1..4: weather, temperature, humidity, wind.
    InfoResult[1].AttriName = "天气";
    InfoResult[2].AttriName = "气温";
    InfoResult[3].AttriName = "湿度";
    InfoResult[4].AttriName = "风";

    ReadData();

    // Root is a placeholder node; the real tree hangs off Root->Children.
    Node *Root = new Node;
    BuildTree(0, NUM - 1, Root);
    ShowTree(Root);
    return 0;
}