线性回归
- 需求:从文件读取数据对,计算回归函数及系数
- 实现1:使用 Apache Commons Math 的 SimpleRegression;定义函数 getData 从文件读取数据,并返回填充好数据的 SimpleRegression 对象
1 import java.io.File; 2 import java.io.FileNotFoundException; 3 import java.util.Scanner; 4 import org.apache.commons.math3.stat.regression.SimpleRegression; 5 6 public class Example1 { 7 public static void main(String[] args) { 8 SimpleRegression sr = getData("data/Data1.dat"); 9 double m = sr.getSlope(); 10 double b = sr.getIntercept(); 11 double r = sr.getR(); // correlation coefficient 12 double r2 = sr.getRSquare(); 13 double sse = sr.getSumSquaredErrors(); 14 double tss = sr.getTotalSumSquares(); 15 16 System.out.printf("y = %.6fx + %.4f%n", m, b); 17 System.out.printf("r = %.6f%n", r); 18 System.out.printf("r2 = %.6f%n", r2); 19 System.out.printf("EV = %.5f%n", tss - sse); 20 System.out.printf("UV = %.4f%n", sse); 21 System.out.printf("TV = %.3f%n", tss); 22 } 23 24 public static SimpleRegression getData(String data) { 25 SimpleRegression sr = new SimpleRegression(); 26 try { 27 Scanner fileScanner = new Scanner(new File(data)); 28 fileScanner.nextLine(); // read past title line 29 int n = fileScanner.nextInt(); 30 fileScanner.nextLine(); // read past line of labels 31 fileScanner.nextLine(); // read past line of labels 32 for (int i = 0; i < n; i++) { 33 String line = fileScanner.nextLine(); 34 Scanner lineScanner = new Scanner(line).useDelimiter("\t"); 35 double x = lineScanner.nextDouble(); 36 double y = lineScanner.nextDouble(); 37 sr.addData(x, y); 38 } 39 } catch (FileNotFoundException e) { 40 System.err.println(e); 41 } 42 return sr; 43 } 44 }
- 实现2:直接计算统计量
import java.io.File;
import java.io.FileNotFoundException;
import java.util.Scanner;

/**
 * Computes a simple linear regression directly from the running sums
 * of x, x^2, y, y^2 and x*y read from a data file.
 */
public class Example2 {
    private static double sX=0, sXX=0, sY=0, sYY=0, sXY=0;
    private static int n=0;

    /**
     * Loads {@code data/Data1.dat} and prints the fitted line, r, r^2 and the
     * values labelled EV, UV and TV.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        getData("data/Data1.dat");
        if (n < 2) {
            // With fewer than two points every formula below divides by zero.
            System.err.println("Need at least two data points, got " + n);
            return;
        }
        double m = (n*sXY - sX*sY)/(n*sXX - sX*sX);
        double b = sY/n - m*sX/n;
        double r2 = m*m*(n*sXX - sX*sX)/(n*sYY - sY*sY);
        // sqrt alone always yields a non-negative value; the correlation
        // coefficient must carry the sign of the slope.
        double r = Math.copySign(Math.sqrt(r2), m);
        double tv = sYY - sY*sY/n;
        double mX = sX/n;  // mean value of x
        double ev = (sXX - 2*mX*sX + n*mX*mX)*m*m;
        double uv = tv - ev;

        System.out.printf("y = %.6fx + %.4f%n", m, b);
        System.out.printf("r = %.6f%n", r);
        System.out.printf("r2 = %.6f%n", r2);
        System.out.printf("EV = %.5f%n", ev);
        System.out.printf("UV = %.4f%n", uv);
        System.out.printf("TV = %.3f%n", tv);
    }

    /**
     * Reads the data file and accumulates the sums used by {@link #main}.
     *
     * <p>The file is consumed as: one title line, an integer count, one line
     * of labels, then that many lines each holding a tab-separated x and y.
     * The static accumulators are reset first, so calling this method again
     * does not add to the totals of a previous call. If the file does not
     * exist the exception is printed to {@code System.err} and the count
     * stays 0.
     *
     * @param data path of the data file
     */
    public static void getData(String data) {
        sX = sXX = sY = sYY = sXY = 0;
        n = 0;
        // try-with-resources closes the underlying file even if parsing fails.
        try (Scanner fileScanner = new Scanner(new File(data))) {
            fileScanner.nextLine();  // read past title line
            n = fileScanner.nextInt();
            fileScanner.nextLine();  // consume the remainder of the count line
            fileScanner.nextLine();  // read past line of labels
            for (int i = 0; i < n; i++) {
                String line = fileScanner.nextLine();
                try (Scanner lineScanner = new Scanner(line).useDelimiter("\t")) {
                    double x = lineScanner.nextDouble();
                    double y = lineScanner.nextDouble();
                    sX += x;
                    sXX += x*x;
                    sY += y;
                    sYY += y*y;
                    sXY += x*y;
                }
            }
        } catch (FileNotFoundException e) {
            System.err.println(e);
        }
    }
}
y = 0.882279x + 18.8739
r = 0.935222
r2 = 0.874641
EV = 1423.35676
UV = 204.0042
TV = 1627.361
- 实现3:对辅助类进行实例化,并绘图
Example3.java
1 import java.io.File; 2 import javax.swing.JFrame; 3 4 public class Example3 { 5 public static void main(String[] args) { 6 Data data = new Data(new File("data/Data1.dat")); 7 JFrame frame = new JFrame(data.getTitle()); 8 frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 9 RegressionPanel panel = new RegressionPanel(data); 10 frame.add(panel); 11 frame.pack(); 12 frame.setSize(500, 422); 13 frame.setResizable(false); 14 frame.setLocationRelativeTo(null); // center frame on screen 15 frame.setVisible(true); 16 } 17 }
Data.java
import java.io.File;
import java.io.FileNotFoundException;
import java.util.Scanner;

/**
 * A set of (x, y) observations read from a text file, together with the
 * sums, extremes, means and least-squares line computed from them.
 *
 * <p>The constructor consumes the file as: a title line, an integer count
 * {@code n}, two tokens naming the x and y variables, and then {@code n}
 * pairs of numbers.
 */
public class Data {
    private String title, xName, yName;
    private int n;
    private double[] x, y;
    private double sX, sXX, sY, sYY, sXY, minX, minY, maxX, maxY;
    private double meanX, meanY, slope, intercept, corrCoef;

    /**
     * Reads the observations from the given file and computes all derived
     * statistics.
     *
     * <p>If the file does not exist the exception is printed to
     * {@code System.err} and the object is left with default field values
     * (null strings and arrays, zero numbers).
     *
     * @param inputFile the data file to read
     */
    public Data(File inputFile) {
        // try-with-resources closes the underlying file even if parsing fails.
        try (Scanner input = new Scanner(inputFile)) {
            title = input.nextLine();
            n = input.nextInt();
            xName = input.next();
            yName = input.next();
            input.nextLine();  // consume the remainder of the labels line
            x = new double[n];
            y = new double[n];
            minX = minY = Double.POSITIVE_INFINITY;
            maxX = maxY = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < n; i++) {
                double xi = x[i] = input.nextDouble();
                double yi = y[i] = input.nextDouble();
                sX += xi;
                sXX += xi*xi;
                sY += yi;
                sYY += yi*yi;
                sXY += xi*yi;
                minX = (xi < minX? xi: minX);
                minY = (yi < minY? yi: minY);
                maxX = (xi > maxX? xi: maxX);
                maxY = (yi > maxY? yi: maxY);
            }
            meanX = sX/n;
            meanY = sY/n;
            slope = (n*sXY - sX*sY)/(n*sXX - sX*sX);
            intercept = meanY - slope*meanX;
            // Multiplying by the slope keeps the sign of the correlation.
            corrCoef = slope*Math.sqrt((n*sXX - sX*sX)/(n*sYY - sY*sY));
        } catch (FileNotFoundException e) {
            System.err.println(e);
        }
    }

    /** @return the first line of the data file */
    public String getTitle() {
        return title;
    }

    /** @return the label read for the x variable */
    public String getXName() {
        return xName;
    }

    /** @return the label read for the y variable */
    public String getYName() {
        return yName;
    }

    /** @return the number of (x, y) pairs declared in the file */
    public int getN() {
        return n;
    }

    /** @return the internal array of x values (not a copy) */
    public double[] getX() {
        return x;
    }

    /** @return the internal array of y values (not a copy) */
    public double[] getY() {
        return y;
    }

    /** @return the mean of the x values */
    public double getMeanX() {
        return meanX;
    }

    /** @return the mean of the y values */
    public double getMeanY() {
        return meanY;
    }

    /** @return the slope of the least-squares line */
    public double getSlope() {
        return slope;
    }

    /** @return the intercept of the least-squares line */
    public double getIntercept() {
        return intercept;
    }

    /** @return the correlation coefficient, signed like the slope */
    public double getCorrCoef() {
        return corrCoef;
    }

    /**
     * @return a new {@code n x 2} array whose rows are the (x, y) pairs
     */
    public double[][] getTable() {
        double[][] table = new double[n][2];
        for (int i = 0; i < n; i++) {
            table[i][0] = x[i];
            table[i][1] = y[i];
        }
        return table;
    }

    /** @return the smallest x value read */
    public double getMinX() {
        return minX;
    }

    /** @return the smallest y value read */
    public double getMinY() {
        return minY;
    }

    /** @return the largest x value read */
    public double getMaxX() {
        return maxX;
    }

    /** @return the largest y value read */
    public double getMaxY() {
        return maxY;
    }
}
RegressionPanel.java
import java.awt.BasicStroke; import java.awt.Color; import java.awt.Graphics; import java.awt.Graphics2D; import javax.swing.JPanel; public class RegressionPanel extends JPanel { private static final int WIDTH=500, HEIGHT=400, BUFFER=28, MARGIN=40; private final Data data; private double xMin, xMax, yMin, yMax, xRange, yRange, gWidth, gHeight; private double slope, intercept; public RegressionPanel(Data data) { this.data = data; this.setSize(WIDTH, HEIGHT); this.xMin = data.getMinX(); this.xMax = data.getMaxX(); this.yMin = data.getMinY(); this.yMax = data.getMaxY(); this.slope = data.getSlope(); this.intercept = data.getIntercept(); this.xRange = xMax - xMin; this.yRange = yMax - yMin; this.gWidth = WIDTH - 2*MARGIN - BUFFER; this.gHeight = HEIGHT - 2*MARGIN - BUFFER; setBackground(Color.WHITE); } @Override public void paintComponent(Graphics g) { super.paintComponent(g); Graphics2D g2 = (Graphics2D)g; g2.setStroke(new BasicStroke(1)); drawGrid(g2); drawPoints(g2, data.getX(), data.getY()); drawLine(g2); } private void drawGrid(Graphics2D g2) { g2.setStroke(new BasicStroke(1)); double xGd = Math.pow(10, Math.floor(Math.log10(xRange))); int xd = dToI(xGd); int x0 = dToI(xGd*Math.floor(xMin/xGd)); int xn = dToI(xGd*Math.ceil(xMax/xGd)); for (int xi = x0; xi <= xn; xi += xd) { g2.setColor(Color.LIGHT_GRAY); int p = f(xi); g2.drawLine(p, 0, p, HEIGHT-18); // vertical lines g2.setColor(Color.BLACK); g2.drawString(""+xi, p-8, HEIGHT-4); } double yGd = Math.pow(10, Math.floor(Math.log10(yRange))); int yd = dToI(yGd); int y0 = dToI(xGd*Math.floor(xMin/yGd)); int yn = dToI(xGd*Math.ceil(yMax/yGd)); for (int yi = y0; yi <= yn; yi += yd) { g2.setColor(Color.LIGHT_GRAY); int q = g(yi); g2.drawLine(BUFFER, q, WIDTH, q); // horizontal lines g2.setColor(Color.LIGHT_GRAY); g2.setColor(Color.BLACK); g2.drawString((yi<100?" 
":"")+yi, 2, q+5); } } private void drawPoints(Graphics2D g2, double[] x, double[] y) { g2.setColor(Color.BLACK); for (int i = 0; i < x.length; i++) { int u = f(x[i]); int v = g(y[i]); g2.fillOval(u-3, v-3, 6, 6); // coordinates are at NW corners } } private void drawLine(Graphics2D g2) { g2.setColor(Color.BLUE); g2.setStroke(new BasicStroke(2)); int p0 = BUFFER; int q0 = g(yLine(fInv(p0))); int p1 = WIDTH; int q1 = g(yLine(fInv(p1))); g2.drawLine(p0, q0, p1, q1); } private double yLine(double x) { return slope*x + intercept; } private int dToI(double x) { return (int)Math.round(x); } private int f(double x) { return dToI((x - xMin)*gWidth/xRange) + BUFFER + MARGIN; } private int g(double y) { return dToI(gHeight - (y - yMin)*gHeight/yRange) + MARGIN; } private double fInv(int p) { return (p - BUFFER - MARGIN)*xRange/gWidth + xMin; } private double gInv(int q) { return yMin + (gHeight + MARGIN - q)*yRange/gHeight; } }
多项式回归
- 需求:已知刹车时的车速与刹车距离的数据,求解刹车距离关于车速的二次多项式回归函数
- 实现:最小二乘法,解方程组,LU分解
1 import org.apache.commons.math3.linear.*; 2 3 public class Example4 { 4 static double[] x = {20, 30, 40, 50, 60, 70}; 5 static double[] y = {52, 87, 136, 203, 290, 394}; 6 static int n = y.length; // 6 7 8 public static void main(String[] args) { 9 double[][] a = new double[3][3]; 10 double[] w = new double[3]; 11 deriveNormalEquations(a, w); 12 printNormalEquations(a, w); 13 double[] b = solveNormalEquations(a, w); 14 printResults(b); 15 } 16 17 public static void deriveNormalEquations(double[][] a, double[] w) { 18 for (int i = 0; i < n; i++) { 19 double xi = x[i]; 20 double yi = y[i]; 21 a[0][0] = n; 22 a[0][1] = a[1][0] += xi; 23 a[0][2] = a[1][1] = a[2][0] += xi*xi; 24 a[1][2] = a[2][1] += xi*xi*xi; 25 a[2][2] += xi*xi*xi*xi; 26 w[0] += yi; 27 w[1] += xi*yi; 28 w[2] += xi*xi*yi; 29 } 30 } 31 32 public static void printNormalEquations(double[][] a, double[] w) { 33 for (int i = 0; i < 3; i++) { 34 System.out.printf("%8.0fb0 + %6.0fb1 + %8.0fb2 = %7.0f%n", 35 a[i][0], a[i][1], a[i][2], w[i]); 36 } 37 } 38 39 /* Solves the matrix equation a*b = w for b[], representing a[] 40 as RealMatrix m and b[] as RealVector v: 41 */ 42 private static double[] solveNormalEquations(double[][] a, double[] w) { 43 RealMatrix m = new Array2DRowRealMatrix(a, false); 44 LUDecomposition lud = new LUDecomposition(m); 45 DecompositionSolver solver = lud.getSolver(); 46 RealVector v = new ArrayRealVector(w, false); 47 return solver.solve(v).toArray(); 48 } 49 50 private static void printResults(double[] b) { 51 System.out.printf("f(t) = %.2f + %.3ft + %.5ft^2%n", b[0], b[1], b[2]); 52 System.out.printf("f(55) = %.1f%n", f(55, b)); 53 } 54 55 private static double f(double t, double[] b) { 56 return b[0] + b[1]*t + b[2]*t*t; 57 } 58 }
6b0 + 270b1 + 13900b2 = 1162
270b0 + 13900b1 + 783000b2 = 64220
13900b0 + 783000b1 + 46750000b2 = 3798800
f(t) = 40.73 + -1.170t + 0.08875t^2
f(55) = 244.8
多元线性回归
- 需求:变量y依赖于多个变量
- 实现:直接求解或通过Apache Commons
Example5.java
1 import org.apache.commons.math3.linear.*; 2 3 public class Example5 { 4 static double[] x = {10, 9, 12, 10, 9, 10, 8, 11}; 5 static double[] y = {59, 57, 61, 52, 48, 55, 51, 62}; 6 static double[] z = {71, 68, 76, 56, 57, 77, 55, 67}; 7 static int n = z.length; // 8 8 9 public static void main(String[] args) { 10 double[][] a = new double[3][3]; 11 double[] w = new double[3]; 12 deriveNormalEquations(a, w); 13 printNormalEquations(a, w); 14 double[] b = solveNormalEquations(a, w); 15 printResults(b); 16 } 17 18 public static void deriveNormalEquations(double[][] a, double[] w) { 19 for (int i = 0; i < n; i++) { 20 double xi = x[i]; 21 double yi = y[i]; 22 double zi = z[i]; 23 a[0][0] = n; 24 a[0][1] = a[1][0] += xi; 25 a[0][2] = a[2][0] += yi; 26 a[1][1] += xi*xi; 27 a[1][2] = a[2][1] += xi*yi; 28 a[2][2] += yi*yi; 29 w[0] += zi; 30 w[1] += xi*zi; 31 w[2] += yi*zi; 32 } 33 } 34 35 public static void printNormalEquations(double[][] a, double[] w) { 36 for (int i = 0; i < 3; i++) { 37 System.out.printf("%6.0fx0 + %4.0fx1 + %5.0fx2 = %5.0f%n", 38 a[i][0], a[i][1], a[i][2], w[i]); 39 } 40 } 41 42 private static double[] solveNormalEquations(double[][] a, double[] w) { 43 RealMatrix m = new Array2DRowRealMatrix(a, false); 44 LUDecomposition lud = new LUDecomposition(m); 45 DecompositionSolver solver = lud.getSolver(); 46 RealVector v = new ArrayRealVector(w, false); 47 return solver.solve(v).toArray(); 48 } 49 50 private static void printResults(double[] b) { 51 System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]); 52 System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b)); 53 System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b)); 54 System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b)); 55 } 56 57 private static double f(double s, double t, double[] b) { 58 return b[0] + b[1]*s + b[2]*t; 59 } 60 }
Example6.java
1 import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; 2 3 public class Example6 { 4 static double[][] x = { {10, 59}, {9, 57}, {12, 61}, {10, 52}, {9, 48}, 5 {10, 55}, {8, 51}, {11, 62} }; 6 static double[] y = {71, 68, 76, 56, 57, 77, 55, 67}; 7 8 public static void main(String[] args) { 9 OLSMultipleLinearRegression mlr = new OLSMultipleLinearRegression(); 10 mlr.newSampleData(y, x); 11 double[] b = mlr.estimateRegressionParameters(); 12 printResults(b); 13 } 14 15 private static void printResults(double[] b) { 16 System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]); 17 System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b)); 18 System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b)); 19 System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b)); 20 } 21 22 private static double f(double s, double t, double[] b) { 23 return b[0] + b[1]*s + b[2]*t; 24 } 25 }
8x0 + 79x1 + 445x2 = 527
79x0 + 791x1 + 4427x2 = 5254
445x0 + 4427x1 + 24929x2 = 29543
f(s, t) = -5.75 + 1.55s + 1.01t
f(10, 59) = 69.5
f(9, 57) = 65.9
f(11, 64) = 76.1