zoukankan      html  css  js  c++  java
  • ID3决策树算法

    基本概念:

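    信息熵是对信息不确定程度的一种度量。假定一个系统 $s$ 具有概率分布 $p=\{p_i\}$($0 \le p_i \le 1$,$\sum_{i=1}^{n} p_i = 1$),$i=1,2,\dots,n$,则系统 $s$ 的信息熵定义为 $H(s) = -\sum_{i=1}^{n} p_i \log_2 p_i$(约定 $0 \log_2 0 = 0$)。假设 $X$ 是一个集合,如果存在一组集合 $A_1, A_2, \dots, A_n$ 满足下列条件:(1) $A_i \neq \emptyset$;(2) 当 $i \neq j$ 时 $A_i \cap A_j = \emptyset$;(3) $\bigcup_{i=1}^{n} A_i = X$,则称 $A_1, \dots, A_n$ 是集合 $X$ 的一个划分。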

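    ID3算法使用信息熵作为度量标准:选择划分后加权熵最小(即信息增益最大)的属性作为分类属性,递归地完成决策树的构造。其中,属性的熵定义为按该属性各个取值划分所得子集的熵的加权和,权重为各子集元组数占总元组数的比例,即 $H(D \mid A) = \sum_{v} \frac{|D_v|}{|D|} H(D_v)$;信息增益为 $\mathrm{Gain}(D, A) = H(D) - H(D \mid A)$。在生成树的过程中,每个节点只对应一个属性值(权熵相同的属性值看成一个属性值)。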

    树的递归结束条件是:划分得到的集合中所有元组都属于同一类;或者已达到所要求的深度;或者某个类的元组个数达到了一定的阈值。

    这是我写的一个ID3算法的例子:

    #include<stdio.h>
    #include<math.h>
    #include<string.h>
    #include<stdlib.h>
    /* Class labels stored in column KIND of an encoded row. */
    enum { SHORT = 0, MEDIUM = 1, TALL = 2 };
    /* Gender codes stored in column GENDER (identifier spellings kept for existing users). */
    enum { MAIL = 0, FEMAIL = 1 };
    /* Column indices of an encoded row; GENDER and HEIGHT also serve as attribute ids. */
    enum { GENDER = 0, HEIGHT = 1, KIND = 2 };
    /* Value of Node.attribute that marks a leaf node. */
    enum { LEAF = -1 };
    /* One node of the decision tree. */
    struct TNODE
    {
        int attribute;            /* attribute id this node splits on, or LEAF for a leaf */
        int arriv_value;          /* parent's attribute value that leads here (main passes -1 for the root) */
        struct TNODE *child[50];  /* child[v] is the subtree for attribute value v; unused slots are NULL */
        int childCount;           /* number of child slots in use (0 for a leaf) */
        int classification;       /* predicted class for a leaf; -1 for an internal node */
    };
    typedef struct TNODE Node;
    /* Number of distinct values of each attribute id: gender has 2, height has 6 buckets. */
    int attriCnt[10]={2,6};
    /* Number of class labels (Short, Medium, Tall). It must cover every label value
       stored in column KIND, otherwise Entropy() silently ignores the missing classes. */
    int classCnt=3;
    /* Encoded training rows: column 0 gender, 1 height bucket, 2 class label. */
    int trainingData[100][30];
    /* Encoded test rows: column 0 gender, 1 height bucket. */
    int testData[100][3];
    /*
     * Shannon entropy, in bits, of the class label (column KIND) over a subset
     * of trainingData.
     *
     * indexArray: row indices of the tuples to count.
     * len:        number of entries in indexArray.
     * Returns -sum(p_c * log2(p_c)) over labels c in [0, classCnt), where p_c is
     * the fraction of the len tuples carrying label c. Returns 0 when len <= 0.
     */
    double Entropy(int *indexArray,int len)
    {
        int cnt[10] = {0};
        const int slots = (int)(sizeof cnt / sizeof cnt[0]);
        /* Never index past cnt[], whatever classCnt is set to. */
        const int classes = classCnt < slots ? classCnt : slots;
        double sum = 0;
        int i;

        if (len <= 0)
            return 0;
        /* Single pass: tally each label; labels outside [0, classes) are not counted. */
        for (i = 0; i < len; i++) {
            int label = trainingData[indexArray[i]][KIND];
            if (label >= 0 && label < classes)
                cnt[label]++;
        }
        for (i = 0; i < classes; i++) {
            double p;
            if (cnt[i] == 0)
                continue;   /* 0 * log2(0) is taken as 0 */
            p = (double)cnt[i] / len;
            sum -= p * log2(p);
        }
        return sum;
    }
    /*
     * Information gain of splitting the tuples in indexArray on attribute `attri`
     * (the name is kept for existing callers; it computes "gain"):
     *
     *   Entropy(subset) - sum_v (|subset_v| / len) * Entropy(subset_v)
     *
     * where subset_v holds the tuples whose column `attri` equals v, for
     * v in [0, attriCnt[attri]). Returns 0 when len <= 0.
     */
    double Grain(int *indexArray,int attri,int len)
    {
        /* Must hold a whole partition, which can be as large as trainingData itself
           (a 10-element buffer overflows as soon as 11 tuples share a value). */
        int subIndexArray[sizeof trainingData / sizeof trainingData[0]];
        const int capacity = (int)(sizeof subIndexArray / sizeof subIndexArray[0]);
        double conditional = 0;
        int value, j;

        if (len <= 0)
            return 0;
        for (value = 0; value < attriCnt[attri]; value++) {
            int sublen = 0;
            for (j = 0; j < len && sublen < capacity; j++) {
                if (trainingData[indexArray[j]][attri] == value)
                    subIndexArray[sublen++] = indexArray[j];
            }
            conditional += (double)sublen / len * Entropy(subIndexArray, sublen);
        }
        return Entropy(indexArray, len) - conditional;
    }
    /*
     * Majority class label among the selected training tuples.
     *
     * chooseIndex: row indices into trainingData.
     * lines:       number of entries in chooseIndex.
     * Returns the label in 0..2 with the highest count. Ties go to the smaller
     * label, and an empty selection yields 0.
     */
    int toClass(int *chooseIndex,int lines)
    {
        int cnt[3] = {0, 0, 0};
        int i, best = 0;

        for (i = 0; i < lines; i++) {
            int label = trainingData[chooseIndex[i]][KIND];
            /* Skip labels outside the tally table instead of writing past it. */
            if (label >= 0 && label < 3)
                cnt[label]++;
        }
        for (i = 1; i < 3; i++) {
            if (cnt[i] > cnt[best])
                best = i;
        }
        return best;
    }
    /*
     * Report whether every selected tuple carries the same class label (column KIND).
     * Returns 1 if they all match (selections of zero or one tuple count as uniform),
     * 0 as soon as a differing label is found.
     */
    int check_attribute(int *chooseIndex,int len)
    {
        int first, pos;

        if (len < 2)
            return 1;
        first = trainingData[chooseIndex[0]][KIND];
        for (pos = 1; pos < len; pos++) {
            if (trainingData[chooseIndex[pos]][KIND] != first)
                return 0;
        }
        return 1;
    }
    /* Allocate a childless node; the program exits if memory runs out. */
    static Node *allocNode(int attribute, int arriv_value, int classification)
    {
        Node *no = malloc(sizeof *no);
        int i;

        if (no == NULL) {
            fprintf(stderr, "out of memory while building the decision tree\n");
            exit(EXIT_FAILURE);
        }
        no->attribute = attribute;
        no->arriv_value = arriv_value;
        no->classification = classification;
        no->childCount = 0;
        for (i = 0; i < (int)(sizeof no->child / sizeof no->child[0]); i++)
            no->child[i] = NULL;
        return no;
    }

    /*
     * Recursively build an ID3 decision tree.
     *
     * chooseIndex:      row indices into trainingData for the tuples at this node.
     * lines:            number of entries in chooseIndex.
     * remain_attribute: attribute ids not yet used on this path.
     * attriNumber:      number of entries in remain_attribute.
     * arriv_value:      parent's attribute value that leads here (main passes -1).
     *
     * Returns NULL only when lines == 0. A LEAF node labelled with the majority
     * class is returned when all tuples share one class or no attributes remain.
     * Otherwise an internal node is returned, split on the attribute with the
     * largest Grain(); a partition with no tuples becomes a LEAF carrying this
     * node's majority class. Nodes are heap-allocated; the caller owns the tree.
     */
    Node *buildTree(int *chooseIndex,int lines,int *remain_attribute,int attriNumber,int arriv_value)
    {
        int subRemain_attribute[sizeof attriCnt / sizeof attriCnt[0]];
        const int subCapacity = (int)(sizeof subRemain_attribute / sizeof subRemain_attribute[0]);
        int choose_attribute, majority, k = 0, i, j;
        Node *no;

        if (lines == 0)
            return NULL;
        majority = toClass(chooseIndex, lines);
        /* Stop when the tuples are pure, or when there is no attribute left to split
           on (without this, choose_attribute would never be assigned). */
        if (check_attribute(chooseIndex, lines) == 1 || attriNumber <= 0)
            return allocNode(LEAF, arriv_value, majority);

        /* Pick the attribute with the largest information gain (first one on ties). */
        choose_attribute = remain_attribute[0];
        if (attriNumber > 1) {
            double maxgrain = -1;
            for (i = 0; i < attriNumber; i++) {
                double gain = Grain(chooseIndex, remain_attribute[i], lines);
                if (gain > maxgrain) {
                    maxgrain = gain;
                    choose_attribute = remain_attribute[i];
                }
            }
        }

        /* Attributes still available to the children. */
        for (i = 0; i < attriNumber && k < subCapacity; i++) {
            if (remain_attribute[i] != choose_attribute)
                subRemain_attribute[k++] = remain_attribute[i];
        }

        no = allocNode(choose_attribute, arriv_value, -1);
        no->childCount = attriCnt[choose_attribute];
        /* child[] has a fixed capacity; never index past it. */
        if (no->childCount > (int)(sizeof no->child / sizeof no->child[0]))
            no->childCount = (int)(sizeof no->child / sizeof no->child[0]);
        for (i = 0; i < no->childCount; i++) {
            int subChooseIndex[sizeof trainingData / sizeof trainingData[0]];
            int subLines = 0;
            for (j = 0; j < lines; j++) {
                if (trainingData[chooseIndex[j]][choose_attribute] == i)
                    subChooseIndex[subLines++] = chooseIndex[j];
            }
            /* An empty partition gets a leaf with this node's majority class (standard
               ID3) instead of a NULL child that classification cannot follow. */
            no->child[i] = subLines > 0
                ? buildTree(subChooseIndex, subLines, subRemain_attribute, k, i)
                : allocNode(LEAF, i, majority);
        }
        return no;
    }
    /* Indent by two tab characters per tree level; non-positive depths print nothing. */
    void blank(int deep)
    {
        int level;

        for (level = deep; level > 0; level--)
            fputs("\t\t", stdout);
    }
    /*
     * Print the tree rooted at `root` depth-first, one record per node, each line
     * indented with blank(deep). A NULL subtree prints nothing.
     */
    void Triverse(Node *root,int deep)
    {
        int i, limit;

        if (root == NULL)
            return;
        blank(deep);
        switch (root->attribute) {
        case GENDER:
            printf(" classification:gender\n");
            break;
        case HEIGHT:
            printf("classification:height\n");
            break;
        case LEAF:
            printf("leaf arrived\n");
            break;
        default:
            printf("%d\n", root->attribute);
            break;
        }
        blank(deep);
        printf("arriv_value: %d\n", root->arriv_value);
        blank(deep);
        printf("childCount: %d\n", root->childCount);
        blank(deep);
        printf("classification: %d\n", root->classification);
        blank(deep);
        printf("------------------------------------------\n");

        /* Never read past the fixed-size child array. */
        limit = root->childCount;
        if (limit > (int)(sizeof root->child / sizeof root->child[0]))
            limit = (int)(sizeof root->child / sizeof root->child[0]);
        for (i = 0; i < limit; i++)
            Triverse(root->child[i], deep + 1);
    }
    /*
     * Walk the tree for row `lineNumber` of testData and print the predicted class.
     * Prints "classify failed!" on a NULL subtree, an out-of-range row/attribute
     * or value, or a leaf whose label is not 0..2.
     */
    void Classify(int lineNumber,Node *root)
    {
        const int rows = (int)(sizeof testData / sizeof testData[0]);
        const int cols = (int)(sizeof testData[0] / sizeof testData[0][0]);
        int childIndex;

        if (root == NULL) {
            printf("classify failed!\n");
            return;
        }
        /* Detect leaves by their LEAF marker: an internal node may have a NULL
           child[0] when no training tuple had attribute value 0. */
        if (root->attribute == LEAF) {
            switch (root->classification) {
            case SHORT:
                printf("the training data belongs to Short\n");
                break;
            case MEDIUM:
                printf("the training data belongs to Medium\n");
                break;
            case TALL:
                printf("the training data belongs to Tall\n");
                break;
            default:
                printf("classify failed!\n");
                break;
            }
            return;
        }
        if (lineNumber < 0 || lineNumber >= rows ||
            root->attribute < 0 || root->attribute >= cols) {
            printf("classify failed!\n");
            return;
        }
        childIndex = testData[lineNumber][root->attribute];
        /* An attribute value never seen by this node cannot be followed. */
        if (childIndex < 0 || childIndex >= root->childCount ||
            childIndex >= (int)(sizeof root->child / sizeof root->child[0])) {
            printf("classify failed!\n");
            return;
        }
        Classify(lineNumber, root->child[childIndex]);
    }
    int main()
    {
        FILE *fp;
    	fp=fopen("./data.txt","r");
    	if(fp==NULL)
    	{
    		printf("Can not open file
    ");
    		return 0;
    	}
    	char name[10],kind[10],gender[10];
        double height;
    	int lines=0;
    	while(fscanf(fp,"%s",name)!=EOF)
    	{
            fscanf(fp,"%s",gender);
    		if(!strcmp(gender,"F"))
    		{
    			trainingData[lines][0]=FEMAIL;
    		}
    		else trainingData[lines][0]=MAIL;
    		fscanf(fp,"%lf",&height);
            if(height>=1.6&&height<1.7)
            {
                trainingData[lines][1]=0;
            }
            else if(height>=1.7&&height<1.8)
            {
                trainingData[lines][1]=1;
            }
            else if(height>=1.8&&height<1.9)
            {
                trainingData[lines][1]=2;
            }
            else if(height>=1.9&&height<2.0)
            {
                trainingData[lines][1]=3;
            }
            else if(height>=2.0&&height<2.1)
            {
                trainingData[lines][1]=4;
            }
            else if(height>=2.1&&height<=2.2)
            {
                trainingData[lines][1]=5;
            }
            fscanf(fp,"%s",kind);
            if(!strcmp(kind,"Short"))
            {
                trainingData[lines][2]=SHORT;
            }
            else if(!strcmp(kind,"Medium"))
            {
                trainingData[lines][2]=MEDIUM;
            }
            else trainingData[lines][2]=TALL;
            lines++;
    	}
        //printf("lines: %d
    ",lines);
        int i,j;
       /* for(i=0;i<lines;i++)
        {
            for(j=0;j<=2;j++)printf("%d	",trainingData[i][j]);
            printf("
    ");
        }*/
        fclose(fp);fp=NULL;
        int index[100],remain_attribute[100];
        for(i=0;i<lines;i++)index[i]=i;
        for(i=0;i<2;i++)remain_attribute[i]=i;
        printf("print the decision tree:
    ");
        Node *root=buildTree(index,lines,remain_attribute,2,-1);
        Triverse(root,0);
        printf("The training data is:
    ");
        fp=fopen("./testData.txt","r");
        if(fp==NULL)
        {
            printf("Can not open the file!
    ");
            return 0;
        }
        int testlines=0;
        while(fscanf(fp,"%s",name)!=EOF)
    	{
            printf("%s	",name);
            fscanf(fp,"%s",gender);
            printf("%s	",gender);
    		if(!strcmp(gender,"F"))
    		{
    			testData[testlines][0]=FEMAIL;
    		}
    		else testData[testlines][0]=MAIL;
    		fscanf(fp,"%lf",&height);
            printf("%lf
    ",height);
            if(height>=1.6&&height<1.7)
            {
                testData[testlines][1]=0;
            }
            else if(height>=1.7&&height<1.8)
            {
                testData[testlines][1]=1;
            }
            else if(height>=1.8&&height<1.9)
            {
                testData[testlines][1]=2;
            }
            else if(height>=1.9&&height<2.0)
            {
                testData[testlines][1]=3;
            }
            else if(height>=2.0&&height<2.1)
            {
                testData[testlines][1]=4;
            }
            else if(height>=2.1&&height<=2.2)
            {
                testData[testlines][1]=5;
            }
            testlines++;
    	}
        /*for(i=0;i<testlines;i++)
        {
            for(j=0;j<2;j++)
            {
                printf("%d	",testData[i][j]);
            }printf("
    ");
        }*/
        Classify(0,root);
    }
    
    
  • 相关阅读:
    IDE有毒
    Netbeans 8.2关于PHP的新特性
    什么是人格
    谁该赋予一款产品灵魂?
    自从升级到macOS后,整个人都不好了
    公司不是大家庭
    性能各个指标分析
    Sqlserver2012 alwayson部署攻略
    初探Backbone
    SQL Server AlwaysOn架构及原理
  • 原文地址:https://www.cnblogs.com/jackwuyongxing/p/3498655.html
Copyright © 2011-2022 走看看