ID3算法
- 思路:分类算法的输入为训练集,输出为对数据进行分类的函数。ID3算法为分类函数生成分类树
- 需求:对水果训练集的一个维度(是否甜)进行预测
- 实现:决策树,熵函数,ID3,weka库 J48类
ComputeGain.java
/**
 * ID3 worked example: computes the entropy h(m, n) and the information
 * gain g(m[], n[]) for the candidate splits (Size, Color, Surface) of
 * the fruit training set (16 points, 11 of which are sweet).
 */
public class ComputeGain {
    public static void main(String[] args) {
        // Entropy of the whole training set: 11 sweet fruits out of 16.
        System.out.printf("h(11,16) = %.4f%n", h(11, 16));
        System.out.println("Gain(Size):");
        System.out.printf(" h(3,5) = %.4f%n", h(3, 5));
        System.out.printf(" h(6,7) = %.4f%n", h(6, 7));
        System.out.printf(" h(2,4) = %.4f%n", h(2, 4));
        System.out.printf(" g({3,6,2},{5,7,4}) = %.4f%n",
                g(new int[]{3, 6, 2}, new int[]{5, 7, 4}));
        System.out.println("Gain(Color):");
        System.out.printf(" h(3,4) = %.4f%n", h(3, 4));
        System.out.printf(" h(3,5) = %.4f%n", h(3, 5));
        System.out.printf(" h(2,3) = %.4f%n", h(2, 3));
        System.out.printf(" h(2,4) = %.4f%n", h(2, 4));
        System.out.printf(" g({3,3,2,2},{4,5,3,4}) = %.4f%n",
                g(new int[]{3, 3, 2, 2}, new int[]{4, 5, 3, 4}));
        System.out.println("Gain(Surface):");
        System.out.printf(" h(4,7) = %.4f%n", h(4, 7));
        System.out.printf(" h(4,6) = %.4f%n", h(4, 6));
        System.out.printf(" h(3,3) = %.4f%n", h(3, 3));
        System.out.printf(" g({4,4,3},{7,6,3}) = %.4f%n",
                g(new int[]{4, 4, 3}, new int[]{7, 6, 3}));
        System.out.println("Gain(Size|SMOOTH):");
        System.out.printf(" h(1,3) = %.4f%n", h(1, 3));
        System.out.printf(" h(3,3) = %.4f%n", h(3, 3));
        System.out.printf(" g({1,3,0},{3,3,1}) = %.4f%n",
                g(new int[]{1, 3, 0}, new int[]{3, 3, 1}));
        System.out.println("Gain(Color|SMOOTH):");
        System.out.printf(" h(2,3) = %.4f%n", h(2, 3));
        System.out.printf(" g({2,2,0},{3,2,2}) = %.4f%n",
                g(new int[]{2, 2, 0}, new int[]{3, 2, 2}));
        System.out.println("Gain(Size|ROUGH):");
        System.out.printf(" h(3,6) = %.4f%n", h(3, 6));
        System.out.printf(" h(1,2) = %.4f%n", h(1, 2));
        System.out.printf(" g({2,1,1},{2,2,2}) = %.4f%n",
                g(new int[]{2, 1, 1}, new int[]{2, 2, 2}));
        System.out.println("Gain(Color|ROUGH):");
        System.out.printf(" h(4,6) = %.4f%n", h(4, 6));
        // BUG FIX: the label previously read g({1,1,1},{2,2,2}) but the
        // arguments actually passed are {1,0,2,1} and {1,2,2,1}.
        System.out.printf(" g({1,0,2,1},{1,2,2,1}) = %.4f%n",
                g(new int[]{1, 0, 2, 1}, new int[]{1, 2, 2, 1}));
    }

    /**
     * Information gain for the splitting {A1, A2, ...}, where subset Ai
     * has n[i] points, m[i] of which are favorable:
     * g = h(sum m, sum n) - sum_i (n[i] / N) * h(m[i], n[i]).
     *
     * @param m favorable counts per subset
     * @param n total counts per subset (same length as m)
     * @return the gain, in bits
     */
    public static double g(int[] m, int[] n) {
        int sm = 0, sn = 0;
        double nsh = 0.0; // sum of n[i] * h(m[i], n[i])
        for (int i = 0; i < m.length; i++) {
            sm += m[i];
            sn += n[i];
            nsh += n[i] * h(m[i], n[i]);
        }
        return h(sm, sn) - nsh / sn;
    }

    /**
     * Entropy, in bits, for m favorable items out of n.
     * A pure set (m == 0 or m == n) has zero entropy.
     */
    public static double h(int m, int n) {
        if (m == 0 || m == n) {
            return 0;
        }
        double p = (double) m / n, q = 1 - p;
        return -p * lg(p) - q * lg(q);
    }

    /** Returns the binary logarithm of x. */
    public static double lg(double x) {
        return Math.log(x) / Math.log(2);
    }
}
h(11,16) = 0.8960
Gain(Size):
h(3,5) = 0.9710
h(6,7) = 0.5917
h(2,4) = 1.0000
g({3,6,2},{5,7,4}) = 0.0838
Gain(Color):
h(3,4) = 0.8113
h(3,5) = 0.9710
h(2,3) = 0.9183
h(2,4) = 1.0000
g({3,3,2,2},{4,5,3,4}) = 0.0260
Gain(Surface):
h(4,7) = 0.9852
h(4,6) = 0.9183
h(3,3) = 0.0000
g({4,4,3},{7,6,3}) = 0.1206
Gain(Size|SMOOTH):
h(1,3) = 0.9183
h(3,3) = 0.0000
g({1,3,0},{3,3,1}) = 0.5917
Gain(Color|SMOOTH):
h(2,3) = 0.9183
g({2,2,0},{3,2,2}) = 0.5917
Gain(Size|ROUGH):
h(3,6) = 1.0000
h(1,2) = 1.0000
g({2,1,1},{2,2,2}) = 0.2516
Gain(Color|ROUGH):
h(4,6) = 0.9183
g({1,0,2,1},{1,2,2,1}) = 0.9183
1 import weka.classifiers.trees.J48; 2 import weka.core.Instances; 3 import weka.core.Instance; 4 import weka.core.converters.ConverterUtils.DataSource; 5 6 public class TestWekaJ48 { 7 public static void main(String[] args) throws Exception { 8 DataSource source = new DataSource("data/AnonFruit.arff"); 9 Instances instances = source.getDataSet(); 10 instances.setClassIndex(3); // target attribute: (Sweet) 11 12 J48 j48 = new J48(); // an extension of ID3 13 j48.setOptions(new String[]{"-U"}); // use unpruned tree 14 j48.buildClassifier(instances); 15 16 for (Instance instance : instances) { 17 double prediction = j48.classifyInstance(instance); 18 System.out.printf("%4.0f%4.0f%n", 19 instance.classValue(), prediction); 20 } 21 } 22 }
1 1
1 1
1 1
1 0
1 1
0 0
1 1
0 0
0 0
0 0
1 1
1 1
1 1
1 1
0 0
1 1
贝叶斯分类
- 思路:基于训练集统计各特征的条件概率,按贝叶斯公式计算每个类别的后验概率,取后验概率最大的类别作为分类结果
Fruit.java
1 import java.io.File; 2 import java.io.FileNotFoundException; 3 import java.util.HashSet; 4 import java.util.Scanner; 5 import java.util.Set; 6 7 public class Fruit { 8 String name, size, color, surface; 9 boolean sweet; 10 11 public Fruit(String name, String size, String color, String surface, 12 boolean sweet) { 13 this.name = name; 14 this.size = size; 15 this.color = color; 16 this.surface = surface; 17 this.sweet = sweet; 18 } 19 20 @Override 21 public String toString() { 22 return String.format("%-12s%-8s%-8s%-8s%s", 23 name, size, color, surface, (sweet? "T": "F") ); 24 } 25 26 public static Set<Fruit> loadData(File file) { 27 Set<Fruit> fruits = new HashSet(); 28 try { 29 Scanner input = new Scanner(file); 30 for (int i = 0; i < 7; i++) { // read past metadata 31 input.nextLine(); 32 } 33 while (input.hasNextLine()) { 34 String line = input.nextLine(); 35 Scanner lineScanner = new Scanner(line); 36 String name = lineScanner.next(); 37 String size = lineScanner.next(); 38 String color = lineScanner.next(); 39 String surface = lineScanner.next(); 40 boolean sweet = (lineScanner.next().equals("T")); 41 Fruit fruit = new Fruit(name, size, color, surface, sweet); 42 fruits.add(fruit); 43 } 44 } catch (FileNotFoundException e) { 45 System.err.println(e); 46 } 47 return fruits; 48 } 49 50 public static void print(Set<Fruit> fruits) { 51 int k=1; 52 for (Fruit fruit : fruits) { 53 System.out.printf("%2d. %s%n", k++, fruit); 54 } 55 } 56 }
BayesianTest.java
1 import java.io.File; 2 import java.util.Set; 3 4 public class BayesianTest { 5 private static Set<Fruit> fruits; 6 7 public static void main(String[] args) { 8 fruits = Fruit.loadData(new File("data/Fruit.arff")); 9 Fruit fruit = new Fruit("cola", "SMALL", "RED", "SMOOTH", false); 10 double n = fruits.size(); // total number of fruits in training set 11 double sum1 = 0; // number of sweet fruits 12 for (Fruit f : fruits) { 13 sum1 += (f.sweet? 1: 0); 14 } 15 double sum2 = n - sum1; // number of sour fruits 16 double[][] p = new double[4][3]; 17 for (Fruit f : fruits) { 18 if (f.sweet) { 19 p[1][1] += (f.size.equals(fruit.size)? 1: 0)/sum1; 20 p[2][1] += (f.color.equals(fruit.color)? 1: 0)/sum1; 21 p[3][1] += (f.surface.equals(fruit.surface)? 1: 0)/sum1; 22 } else { 23 p[1][2] += (f.size.equals(fruit.size)? 1: 0)/sum2; 24 p[2][2] += (f.color.equals(fruit.color)? 1: 0)/sum2; 25 p[3][2] += (f.surface.equals(fruit.surface)? 1: 0)/sum2; 26 } 27 } 28 double pc1 = p[1][1]*p[2][1]*p[3][1]*sum1/n; 29 double pc2 = p[1][2]*p[2][2]*p[3][2]*sum2/n; 30 System.out.printf("pc1 = %.4f, pc2 = %.4f%n", pc1, pc2); 31 System.out.printf("Predict %s is %s.%n", 32 fruit.name, (pc1 > pc2? "sweet": "sour")); 33 } 34 }
pc1 = 0.0186, pc2 = 0.0150
Predict cola is sweet.
TestWekaBayes.java
1 import java.util.List; 2 import weka.classifiers.Evaluation; 3 import weka.classifiers.bayes.NaiveBayes; 4 import weka.classifiers.evaluation.Prediction; 5 import weka.core.Instance; 6 import weka.core.Instances; 7 import weka.core.converters.ConverterUtils; 8 import weka.core.converters.ConverterUtils.DataSource; 9 10 public class TestWekaBayes { 11 public static void main(String[] args) throws Exception { 12 // ConverterUtils.DataSource source = new ConverterUtils.DataSource("data/AnonFruit.arff"); 13 DataSource source = new DataSource("data/AnonFruit.arff"); 14 Instances train = source.getDataSet(); 15 train.setClassIndex(3); // target attribute: (Sweet) 16 //build model 17 NaiveBayes model=new NaiveBayes(); 18 model.buildClassifier(train); 19 20 //use 21 Instances test = train; 22 Evaluation eval = new Evaluation(test); 23 eval.evaluateModel(model,test); 24 List <Prediction> predictions = eval.predictions(); 25 int k = 0; 26 for (Instance instance : test) { 27 double actual = instance.classValue(); 28 double prediction = eval.evaluateModelOnce(model, instance); 29 System.out.printf("%2d.%4.0f%4.0f", ++k, actual, prediction); 30 System.out.println(prediction != actual? " *": ""); 31 } 32 } 33 }
1. 1 1
2. 1 1
3. 1 1
4. 1 1
5. 1 1
6. 0 1 *
7. 1 1
8. 0 0
9. 0 0
10. 0 1 *
11. 1 1
12. 1 1
13. 1 1
14. 1 1
15. 0 0
16. 1 1
SVM算法
- 思路:生成超平面方程,计算数据点位于哪一边
逻辑回归
- 思路:将目标值属性为布尔值的问题转化成一个数值变量,在转化后的问题上进行线性回归
- 需求:某政党候选人想知道选举获胜的花费
- 实现
1 import org.apache.commons.math3.analysis.function.*; 2 import org.apache.commons.math3.stat.regression.SimpleRegression; 3 4 public class LogisticRegression { 5 static int n = 6; 6 static double[] x = {5, 15, 25, 35, 45, 55}; 7 static double[] p = {2./6,2./5, 4./8, 5./9, 3./5, 4./5}; 8 static double[] y = new double[n]; // y = logit(p) 9 10 public static void main(String[] args) { 11 12 // Transform p-values into y-values: 13 Logit logit = new Logit(); 14 for (int i = 0; i < n; i++) { 15 y[i] = logit.value(p[i]); 16 } 17 18 // Set up input array for linear regression: 19 double[][] data = new double[n][n]; 20 for (int i = 0; i < n; i++) { 21 data[i][0] = x[i]; 22 data[i][1] = y[i]; 23 } 24 25 // Run linear regression of y on x: 26 SimpleRegression sr = new SimpleRegression(); 27 sr.addData(data); 28 29 // Print results: 30 for (int i = 0; i < n; i++) { 31 System.out.printf("x = %2.0f, y = %7.4f%n", x[i], sr.predict(x[i])); 32 } 33 System.out.println(); 34 35 // Convert y-values back to p-values: 36 Sigmoid sigmoid = new Sigmoid(); 37 for (int i = 0; i < n; i++) { 38 double p = sr.predict(x[i]); 39 System.out.printf("x = %2.0f, p = %6.4f%n", x[i], sigmoid.value(p)); 40 } 41 } 42 }
x = 5, y = -0.7797
x = 15, y = -0.4067
x = 25, y = -0.0338
x = 35, y = 0.3392
x = 45, y = 0.7121
x = 55, y = 1.0851
x = 5, p = 0.3144
x = 15, p = 0.3997
x = 25, p = 0.4916
x = 35, p = 0.5840
x = 45, p = 0.6709
x = 55, p = 0.7475
k近邻(k-Nearest Neighbors)
- 思路:找出与待分类样本距离最近的k个训练样本,按这k个邻居的多数类别进行分类
1 import weka.classifiers.lazy.IBk; // K-Nearest Neighbors 2 import weka.core.Instances; 3 import weka.core.Instance; 4 import weka.core.converters.ConverterUtils.DataSource; 5 6 public class TestIBk { 7 public static void main(String[] args) throws Exception { 8 DataSource source = new DataSource("data/AnonFruit.arff"); 9 Instances instances = source.getDataSet(); 10 instances.setClassIndex(3); // target attribute: (Sweet) 11 12 IBk ibk = new IBk(); 13 ibk.buildClassifier(instances); 14 15 for (Instance instance : instances) { 16 double prediction = ibk.classifyInstance(instance); 17 System.out.printf("%4.0f%4.0f%n", 18 instance.classValue(), prediction); 19 } 20 } 21 }
1 1
1 1
1 1
1 0
1 1
0 0
1 1
0 0
0 0
0 0
1 1
1 1
1 1
1 1
0 0
1 1