zoukankan      html  css  js  c++  java
  • ID3决策树算法

    基本概念:

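    信息熵是对信息不确定程度的一种度量。假定一个系统S具有概率分布P={p_i}(0<=p_i<=1,且所有p_i之和为1),i=1,2,3,...,n,则系统S的信息熵定义为 $H(S) = -\sum_{i=1}^{n} p_i \log_2 p_i$(约定 $0 \log_2 0 = 0$)。假设X是一个集合,如果存在一组集合A1,A2,A3,...,An,满足下列条件:(1) 每个 $A_i \neq \emptyset$;(2) 当 $i \neq j$ 时 $A_i \cap A_j = \emptyset$;(3) $A_1 \cup A_2 \cup \dots \cup A_n = X$,则称A1-An是集合X的一个划分。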

    ID3算法使用信息熵作为度量标准,选择划分后加权信息熵最小(即信息增益最大)的属性作为分类属性,完成决策树的构造,其中属性的熵定义为该属性各个属性值的加权熵之和。在生成树的过程中,每个节点只有一个属性值(权熵相同的属性值看成一个属性值)。

    树的递归结束条件是,划分的集合是否属于同一类,或者是否达到了所要求的深度,或者某个类的个数达到了一定的阈值。

    这是我写的一个ID3算法的例子:

    #include<stdio.h>
    #include<math.h>
    #include<string.h>
    #include<stdlib.h>
    /* Class labels stored in the KIND column of a training row. */
    enum { SHORT = 0, MEDIUM = 1, TALL = 2 };
    /* Encoded values of the GENDER column (identifier spellings are the ones the file uses). */
    enum { MAIL = 0, FEMAIL = 1 };
    /* Column indices inside a trainingData / testData row. */
    enum { GENDER = 0, HEIGHT = 1, KIND = 2 };
    /* Value of Node.attribute that marks a leaf node. */
    enum { LEAF = -1 };
    /* One node of the decision tree built by buildTree. */
    struct TNODE
    {
        /* Attribute this node splits on; LEAF for a leaf node. */
        int attribute;
        /* Value of the parent's split attribute that leads to this node (-1 for the root). */
        int arriv_value;
        /* One slot per value of `attribute`; NULL where no training row took that value. */
        struct TNODE *child[50];
        /* Number of child slots in use: 0 for a leaf. */
        int childCount;
        /* Class label for a leaf; -1 for an internal node. */
        int classification;
    };
    typedef struct TNODE Node;
    /* Number of distinct values per attribute column: GENDER has 2, HEIGHT has 6 buckets. */
    int attriCnt[10]={2,6};
    /*
     * Number of class labels in the KIND column. There are three labels
     * (SHORT, MEDIUM, TALL); with the former value of 2 the entropy loop
     * never counted TALL rows, so every entropy and gain was computed
     * from an incomplete distribution.
     */
    int classCnt=3;
    /* Encoded training rows: column 0 = gender, 1 = height bucket, 2 = class label. */
    int trainingData[100][30];
    /* Encoded rows to classify: column 0 = gender, 1 = height bucket. */
    int testData[100][3];
    /*
     * Shannon entropy (base 2) of the class label over a subset of training rows.
     *
     * indexArray: row indices into trainingData that make up the subset
     * len:        number of entries in indexArray
     *
     * Returns -sum(p_c * log2(p_c)) over the classes 0..classCnt-1 that occur
     * in the subset, where p_c is the fraction of the len rows whose KIND
     * column equals c. An empty subset yields 0.
     */
    double Entropy(int *indexArray/* row indices to count */,int len/* number of rows */)
    {
        double sum=0;
        int cnt[10];
        const int slots=(int)(sizeof cnt/sizeof cnt[0]);
        int i;

        memset(cnt,0,sizeof(cnt));

        /* One pass over the rows: tally each row's class label. */
        for(i=0;i<len;i++)
        {
            int label=trainingData[indexArray[i]][KIND];
            /* Only labels below classCnt are counted; the slots bound keeps cnt[] in range. */
            if(label>=0&&label<classCnt&&label<slots)
            {
                cnt[label]++;
            }
        }

        for(i=0;i<classCnt&&i<slots;i++)
        {
            double p;
            if(cnt[i]==0)continue; /* an absent class contributes nothing (0*log 0 := 0) */
            p=cnt[i]*1.0/len;
            sum=sum-p*(log(p)/log(2));
        }
        return sum;
    }
    /*
     * Information gain obtained by splitting a subset of training rows on one attribute.
     *
     * indexArray: row indices into trainingData that make up the subset
     *             (differs from call to call as the tree is built)
     * attri:      attribute (column) to split on
     * len:        number of entries in indexArray
     *
     * Returns Entropy(subset) minus the size-weighted sum of the entropies of
     * the partitions induced by each value 0..attriCnt[attri]-1 of the attribute.
     * Returns 0 for an empty subset.
     */
    double Grain(int *indexArray,int attri,int len)
    {
        /*
         * Scratch list of the rows in the current partition. It is sized to the
         * training table itself: the previous 10-slot buffer overflowed as soon
         * as more than 10 rows shared an attribute value.
         */
        int subIndexArray[sizeof trainingData/sizeof trainingData[0]];
        const int capacity=(int)(sizeof subIndexArray/sizeof subIndexArray[0]);
        double hd;
        double result=0;
        int i,j;

        if(len<=0)return 0; /* avoids 0/0 below; an empty set has nothing to gain */

        hd=Entropy(indexArray,len);
        for(i=0;i<attriCnt[attri];i++)
        {
            int sublen=0;
            double h;
            for(j=0;j<len;j++)
            {
                /* collect the rows whose attribute takes value i */
                if(trainingData[indexArray[j]][attri]==i&&sublen<capacity)
                {
                    subIndexArray[sublen++]=indexArray[j];
                }
            }
            h=Entropy(subIndexArray,sublen); /* entropy of this partition */
            result=result+sublen*1.0/len*h;  /* weighted by the partition's share */
        }
        return hd-result;
    }
    /*
     * Majority class label among a subset of training rows.
     *
     * chooseIndex: row indices into trainingData
     * lines:       number of entries in chooseIndex
     *
     * Returns the label (0, 1 or 2) that occurs most often in the KIND column
     * of the selected rows; ties go to the smallest label, and an empty subset
     * yields 0.
     */
    int toClass(int *chooseIndex,int lines)
    {
        int cnt[3]={0,0,0};
        int flag=0;
        int i;

        for(i=0;i<lines;i++)
        {
            int label=trainingData[chooseIndex[i]][KIND];
            /* ignore labels that would index outside cnt[] */
            if(label>=0&&label<3)
            {
                cnt[label]++;
            }
        }
        /* strict '>' keeps the first (smallest) label on a tie */
        for(i=1;i<3;i++)
        {
            if(cnt[i]>cnt[flag])flag=i;
        }
        return flag;
    }
    /*
     * Tell whether every selected training row carries the same class label.
     *
     * chooseIndex: row indices into trainingData
     * len:         number of entries in chooseIndex
     *
     * Returns 1 when all rows agree on the KIND column (also for zero or one
     * row), 0 as soon as two neighbouring entries disagree.
     */
    int check_attribute(int *chooseIndex,int len)
    {
        int same=1;
        int pos=1;
        /* walk neighbouring pairs and stop at the first mismatch */
        while(same&&pos<len)
        {
            same=(trainingData[chooseIndex[pos]][KIND]==trainingData[chooseIndex[pos-1]][KIND]);
            pos++;
        }
        return same;
    }
    /* Allocate a tree node with every child slot cleared. Returns NULL when malloc fails. */
    static Node *newTreeNode(int attribute,int childCount,int arriv_value,int classification)
    {
        Node *no=malloc(sizeof *no);
        size_t i;
        if(no==NULL)return NULL;
        no->attribute=attribute;
        no->arriv_value=arriv_value;
        no->childCount=childCount;
        no->classification=classification;
        for(i=0;i<sizeof no->child/sizeof no->child[0];i++)no->child[i]=NULL;
        return no;
    }

    /*
     * Recursively build an ID3 decision tree.
     *
     * chooseIndex:      row indices into trainingData selected for this subtree
     * lines:            number of entries in chooseIndex
     * remain_attribute: attributes not yet used on the path from the root
     * attriNumber:      number of entries in remain_attribute
     * arriv_value:      value of the parent's split attribute that leads here
     *
     * Returns NULL when no rows are selected or when allocation fails.
     * Returns a leaf carrying the majority class when all rows share one class
     * or when no attribute is left to split on. The second condition was
     * missing before, so the split attribute was read uninitialised once the
     * attributes ran out on a mixed subset. Otherwise returns an internal node
     * that splits on the remaining attribute with the largest information gain,
     * with one child per attribute value (NULL for values no row takes).
     * The caller owns the returned tree.
     */
    Node *buildTree(int *chooseIndex/* selected rows */,int lines/* row count */,int *remain_attribute/* unused attributes */,int attriNumber/* attribute count */,int arriv_value)
    {
        int subRemain_attribute[10];
        int subChooseIndex[100];
        const int maxRemain=(int)(sizeof subRemain_attribute/sizeof subRemain_attribute[0]);
        const int maxRows=(int)(sizeof subChooseIndex/sizeof subChooseIndex[0]);
        int choose_attribute;
        int i,j;
        int k=0;
        Node *no;

        if(lines<=0)return NULL;

        /* Stop when the subset is pure or there is nothing left to split on. */
        if(attriNumber<=0||check_attribute(chooseIndex,lines)==1)
        {
            return newTreeNode(LEAF,0,arriv_value,toClass(chooseIndex,lines));
        }

        /* With one attribute left it is the only candidate; otherwise take the largest gain. */
        choose_attribute=remain_attribute[0];
        if(attriNumber>1)
        {
            double maxgrain=-1;
            for(i=0;i<attriNumber;i++)
            {
                double temp=Grain(chooseIndex,remain_attribute[i],lines);
                if(temp>maxgrain)
                {
                    maxgrain=temp;
                    choose_attribute=remain_attribute[i];
                }
            }
        }

        /* Attributes still available to the children. */
        for(i=0;i<attriNumber;i++)
        {
            if(remain_attribute[i]!=choose_attribute&&k<maxRemain)
            {
                subRemain_attribute[k++]=remain_attribute[i];
            }
        }

        no=newTreeNode(choose_attribute,attriCnt[choose_attribute],arriv_value,-1);
        if(no==NULL)return NULL;

        /* One subtree per value of the chosen attribute. */
        for(i=0;i<attriCnt[choose_attribute];i++)
        {
            int subLines=0;
            for(j=0;j<lines;j++)
            {
                if(trainingData[chooseIndex[j]][choose_attribute]==i&&subLines<maxRows)
                {
                    subChooseIndex[subLines++]=chooseIndex[j];
                }
            }
            no->child[i]=buildTree(subChooseIndex,subLines,subRemain_attribute,k,i);
        }
        return no;
    }
    /* Print the indentation for tree depth `deep`: two tab characters per level. */
    void blank(int deep)
    {
        int level=0;
        while(level<deep)
        {
            printf("		");
            level++;
        }
    }
    /*
     * Print the subtree rooted at `root` in pre-order, one node per five-line
     * record, indented by blank(deep).
     *
     * root: subtree to print; NULL prints nothing
     * deep: depth of `root`, used only for indentation
     *
     * The string literals had lost their "\n" escapes to raw line breaks, which
     * is not valid C; the escapes are restored here.
     */
    void Triverse(Node *root,int deep)
    {
        int i;
        if(root==NULL)return;

        blank(deep);
        switch(root->attribute)
        {
            case GENDER:printf(" classification:gender\n");break;
            case HEIGHT:printf("calssification:height\n");break;
            case LEAF:printf("leaf arrived\n");break;
            default:printf("%d\n",root->attribute);break;
        }
        blank(deep);
        printf("arriv_value: %d\n",root->arriv_value);
        blank(deep);
        printf("childCount: %d\n",root->childCount);
        blank(deep);
        printf("classification: %d\n",root->classification);
        blank(deep);
        printf("------------------------------------------\n");

        /* children are printed one level deeper; NULL slots print nothing */
        for(i=0;i<root->childCount;i++)
        {
            Triverse(root->child[i],deep+1);
        }
    }
    /*
     * Classify one row of testData by walking the decision tree and print the result.
     *
     * lineNumber: row index into testData
     * root:       (sub)tree to descend
     *
     * Prints "classify failed!" when the row index is out of range, when the
     * walk reaches a missing branch, or when the row's attribute value has no
     * child slot.
     *
     * A leaf is recognised by childCount==0. The former test child[0]==NULL
     * also matched internal nodes whose first branch was empty, so rows that
     * belonged to a later branch were reported as failures.
     */
    void Classify(int lineNumber,Node *root)
    {
        const int maxRows=(int)(sizeof testData/sizeof testData[0]);
        int childIndex;

        if(root==NULL||lineNumber<0||lineNumber>=maxRows)
        {
            printf("classify failed!\n");
            return;
        }
        if(root->childCount==0) /* reached a leaf */
        {
            switch(root->classification)
            {
                case 0:printf("the training data belongs to Short\n");break;
                case 1:printf("the training data belongs to Medium\n");break;
                case 2:printf("the training data belongs to Tall\n");break;
                default:printf("classify failed!\n");break;
            }
            return;
        }

        /* follow the branch selected by the row's value for this node's attribute */
        childIndex=testData[lineNumber][root->attribute];
        if(childIndex<0||childIndex>=root->childCount)
        {
            printf("classify failed!\n");
            return;
        }
        Classify(lineNumber,root->child[childIndex]);
    }
    int main()
    {
        FILE *fp;
    	fp=fopen("./data.txt","r");
    	if(fp==NULL)
    	{
    		printf("Can not open file
    ");
    		return 0;
    	}
    	char name[10],kind[10],gender[10];
        double height;
    	int lines=0;
    	while(fscanf(fp,"%s",name)!=EOF)
    	{
            fscanf(fp,"%s",gender);
    		if(!strcmp(gender,"F"))
    		{
    			trainingData[lines][0]=FEMAIL;
    		}
    		else trainingData[lines][0]=MAIL;
    		fscanf(fp,"%lf",&height);
            if(height>=1.6&&height<1.7)
            {
                trainingData[lines][1]=0;
            }
            else if(height>=1.7&&height<1.8)
            {
                trainingData[lines][1]=1;
            }
            else if(height>=1.8&&height<1.9)
            {
                trainingData[lines][1]=2;
            }
            else if(height>=1.9&&height<2.0)
            {
                trainingData[lines][1]=3;
            }
            else if(height>=2.0&&height<2.1)
            {
                trainingData[lines][1]=4;
            }
            else if(height>=2.1&&height<=2.2)
            {
                trainingData[lines][1]=5;
            }
            fscanf(fp,"%s",kind);
            if(!strcmp(kind,"Short"))
            {
                trainingData[lines][2]=SHORT;
            }
            else if(!strcmp(kind,"Medium"))
            {
                trainingData[lines][2]=MEDIUM;
            }
            else trainingData[lines][2]=TALL;
            lines++;
    	}
        //printf("lines: %d
    ",lines);
        int i,j;
       /* for(i=0;i<lines;i++)
        {
            for(j=0;j<=2;j++)printf("%d	",trainingData[i][j]);
            printf("
    ");
        }*/
        fclose(fp);fp=NULL;
        int index[100],remain_attribute[100];
        for(i=0;i<lines;i++)index[i]=i;
        for(i=0;i<2;i++)remain_attribute[i]=i;
        printf("print the decision tree:
    ");
        Node *root=buildTree(index,lines,remain_attribute,2,-1);
        Triverse(root,0);
        printf("The training data is:
    ");
        fp=fopen("./testData.txt","r");
        if(fp==NULL)
        {
            printf("Can not open the file!
    ");
            return 0;
        }
        int testlines=0;
        while(fscanf(fp,"%s",name)!=EOF)
    	{
            printf("%s	",name);
            fscanf(fp,"%s",gender);
            printf("%s	",gender);
    		if(!strcmp(gender,"F"))
    		{
    			testData[testlines][0]=FEMAIL;
    		}
    		else testData[testlines][0]=MAIL;
    		fscanf(fp,"%lf",&height);
            printf("%lf
    ",height);
            if(height>=1.6&&height<1.7)
            {
                testData[testlines][1]=0;
            }
            else if(height>=1.7&&height<1.8)
            {
                testData[testlines][1]=1;
            }
            else if(height>=1.8&&height<1.9)
            {
                testData[testlines][1]=2;
            }
            else if(height>=1.9&&height<2.0)
            {
                testData[testlines][1]=3;
            }
            else if(height>=2.0&&height<2.1)
            {
                testData[testlines][1]=4;
            }
            else if(height>=2.1&&height<=2.2)
            {
                testData[testlines][1]=5;
            }
            testlines++;
    	}
        /*for(i=0;i<testlines;i++)
        {
            for(j=0;j<2;j++)
            {
                printf("%d	",testData[i][j]);
            }printf("
    ");
        }*/
        Classify(0,root);
    }
    
    
  • 相关阅读:
    The Mac Application Environment 不及格的程序员
    Xcode Plugin: Change Code In Running App Without Restart 不及格的程序员
    The property delegate of CALayer cause Crash. 不及格的程序员
    nil localizedTitle in SKProduct 不及格的程序员
    InApp Purchase 不及格的程序员
    Safari Web Content Guide 不及格的程序员
    在Mac OS X Lion 安装 XCode 3.2 不及格的程序员
    illustrate ARC with graphs 不及格的程序员
    Viewing iPhoneOptimized PNGs 不及格的程序员
    What is the dSYM? 不及格的程序员
  • 原文地址:https://www.cnblogs.com/jackwuyongxing/p/3498655.html
Copyright © 2011-2022 走看看