这里只写一下用C++简单实现的ID3算法决策树
ID3算法基于信息熵和信息增益(即这里所说的“信息获取量”)。
每次建立新节点时,选取信息增益最大(即按该属性划分后信息熵下降最多)的属性进行分割。
决策树还有很多其他算法(如C4.5、CART),它们的主要区别在于选择划分属性的衡量标准不同,
实质上都是按照贪心策略自上而下地建树。
如果树的深度过深,容易过拟合,还需要采取剪枝的手段。
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;

typedef unsigned int ui;
typedef vector< vector<int> > dv; // data set: one row per sample, last column is the class label

const int maxm = 100, maxn = 1000;
const double eps = 1e-7;

// One node of the decision tree. Nodes are stored heap-style in node[]:
// the children of node x are node[2*x] (attribute value non-zero) and
// node[2*x+1] (attribute value zero).
struct Node {
    bool flag[maxm]; // flag[i] is set once attribute i has been used on the path to this node
    int st;          // attribute this node splits on, or -1 for a leaf
    int yes, no;     // number of positive / negative training rows that reached this node
} node[maxn];

// Binary entropy, in bits, of a two-class distribution whose positive class
// has probability p. Values within eps of 0 or 1 count as pure and yield 0.
double cal_entropy(double p)
{
    if (fabs(p) <= eps || fabs(p - 1) <= eps)
        return 0;
    return -(p * log(p) / log(2) + (1 - p) * log(1 - p) / log(2));
}

// Tallies produced by splitting a data set on one binary attribute.
struct SplitCounts {
    int n1, v1; // rows whose attribute is non-zero, and how many of those are positive
    int n2, v2; // rows whose attribute is zero, and how many of those are positive
};

// Counts how attribute k partitions rows. Each row must hold more than k
// values; its last value is the class label (non-zero means positive).
static SplitCounts count_split(const dv& rows, int k)
{
    SplitCounts c = {0, 0, 0, 0};
    for (const vector<int>& row : rows) {
        const bool positive = row.back() != 0;
        if (row[k]) {
            ++c.n1;
            if (positive) ++c.v1;
        } else {
            ++c.n2;
            if (positive) ++c.v2;
        }
    }
    return c;
}

// Information gain for the given tallies. A split that leaves one side empty
// separates nothing, so its gain is 0 (this also avoids dividing by zero).
static double gain_of(const SplitCounts& c)
{
    if (c.n1 == 0 || c.n2 == 0)
        return 0;
    const int total = c.n1 + c.n2;
    const double remainder =
        static_cast<double>(c.n1) / total * cal_entropy(static_cast<double>(c.v1) / c.n1) +
        static_cast<double>(c.n2) / total * cal_entropy(static_cast<double>(c.v2) / c.n2);
    return cal_entropy(static_cast<double>(c.v1 + c.v2) / total) - remainder;
}

// Information gain obtained by splitting data set v on attribute k.
// Returns 0 when the attribute does not separate the rows (or v is empty).
double split(const dv& v, int k)
{
    return gain_of(count_split(v, k));
}

// Greedily builds the subtree rooted at node[x] from the rows in vnode and
// prints "x yes no st " for every node in pre-order.
// All rows are expected to have the same length, with the label in the last column.
void build(int x, const dv& vnode)
{
    if (x < 1 || x >= maxn)
        return; // no slot in node[] for this index

    // Assign (rather than increment) so the counts do not depend on earlier state.
    int positives = 0;
    for (const vector<int>& row : vnode)
        if (row.back())
            ++positives;
    node[x].yes = positives;
    node[x].no = static_cast<int>(vnode.size()) - positives;

    // Number of attribute columns, clamped so flag[] is never indexed out of range.
    int attr_count = 0;
    if (!vnode.empty() && !vnode[0].empty()) {
        attr_count = static_cast<int>(vnode[0].size()) - 1;
        if (attr_count > maxm)
            attr_count = maxm;
    }

    int k = -1;
    // Starting below 0 means the first attribute that separates the rows is
    // accepted even with zero gain; later ones must beat it by more than eps.
    double best = -1;
    if (2 * x + 1 < maxn) { // split only when both children fit inside node[]
        for (int i = 0; i < attr_count; ++i) {
            if (node[x].flag[i])
                continue;
            const SplitCounts counts = count_split(vnode, i);
            if (counts.n1 == 0 || counts.n2 == 0)
                continue; // one side would be empty: nothing to recurse into
            const double gain = gain_of(counts);
            if (gain - best > eps) {
                best = gain;
                k = i;
            }
        }
    }
    node[x].st = k;
    printf("%d %d %d %d ", x, node[x].yes, node[x].no, node[x].st);
    if (k == -1)
        return;

    dv v1, v2; // rows with attribute k non-zero / zero
    for (const vector<int>& row : vnode)
        (row[k] ? v1 : v2).push_back(row);

    // Children inherit the used-attribute set and additionally mark k as used.
    Node& left = node[2 * x];
    Node& right = node[2 * x + 1];
    for (int i = 0; i < attr_count; ++i)
        left.flag[i] = right.flag[i] = node[x].flag[i];
    left.flag[k] = right.flag[k] = true;

    build(2 * x, v1);
    build(2 * x + 1, v2);
}

int n, m, x;
dv v;

// Classifies sample vv (vv[i] is the value of attribute i) by walking the tree
// from node[x]. Returns 1 when the reached node has more positive than
// negative training rows, otherwise 0.
int dfs(int x, const vector<int>& vv)
{
    if (x < 1 || x >= maxn)
        return 0; // not a valid node index
    for (;;) {
        const int attr = node[x].st;
        // Stop at a leaf, when the sample has no value for the split attribute,
        // or when the children would lie outside node[].
        if (attr < 0 || static_cast<size_t>(attr) >= vv.size() || 2 * x + 1 >= maxn)
            return node[x].yes > node[x].no;
        x = vv[attr] ? 2 * x : 2 * x + 1;
    }
}

vector<int> vv;

// Reads "n m", then n training rows of m integers (last one is the label),
// then one test sample of m integers from a.txt; prints the tree and the prediction.
int main()
{
    if (!freopen("a.txt", "r", stdin)) {
        fprintf(stderr, "cannot open a.txt\n");
        return 1;
    }
    if (!(cin >> n >> m) || n < 1 || m < 1 || m - 1 > maxm) {
        fprintf(stderr, "invalid header: need n >= 1 and 1 <= m <= %d\n", maxm + 1);
        return 1;
    }

    v.resize(n);
    for (int i = 0; i < n; i++) {
        v[i].reserve(m);
        for (int j = 0; j < m; j++) {
            if (!(cin >> x)) {
                fprintf(stderr, "training data ended early at row %d, column %d\n", i, j);
                return 1;
            }
            v[i].push_back(x);
        }
    }

    build(1, v);

    for (int i = 0; i < m; i++) {
        if (!(cin >> x)) {
            fprintf(stderr, "test sample ended early at column %d\n", i);
            return 1;
        }
        vv.push_back(x);
    }
    cout << dfs(1, vv) << endl;
}