zoukankan      html  css  js  c++  java
  • [Java] 数据分析 -- 回归分析

    线性回归

    • 需求:从文件读取数据对,计算回归函数及系数
    • 实现1:使用commons.math的SimpleRegression;定义函数getData,从文件读取数据并返回一个SimpleRegression对象
     1 import java.io.File;
     2 import java.io.FileNotFoundException;
     3 import java.util.Scanner;
     4 import org.apache.commons.math3.stat.regression.SimpleRegression;
     5 
     6 public class Example1 {
     7     public static void main(String[] args) {
     8         SimpleRegression sr = getData("data/Data1.dat");
     9         double m = sr.getSlope();
    10         double b = sr.getIntercept();
    11         double r = sr.getR();  // correlation coefficient
    12         double r2 = sr.getRSquare();
    13         double sse = sr.getSumSquaredErrors();
    14         double tss = sr.getTotalSumSquares();
    15 
    16         System.out.printf("y = %.6fx + %.4f%n", m, b);
    17         System.out.printf("r = %.6f%n", r);
    18         System.out.printf("r2 = %.6f%n", r2);
    19         System.out.printf("EV = %.5f%n", tss - sse);
    20         System.out.printf("UV = %.4f%n", sse);
    21         System.out.printf("TV = %.3f%n", tss);
    22     }
    23     
    24     public static SimpleRegression getData(String data) {
    25         SimpleRegression sr = new SimpleRegression();
    26         try {
    27             Scanner fileScanner = new Scanner(new File(data));
    28             fileScanner.nextLine();  // read past title line
    29             int n = fileScanner.nextInt();
    30             fileScanner.nextLine();  // read past line of labels
    31             fileScanner.nextLine();  // read past line of labels
    32             for (int i = 0; i < n; i++) {
    33                 String line = fileScanner.nextLine();
    34                 Scanner lineScanner = new Scanner(line).useDelimiter("\t");
    35                 double x = lineScanner.nextDouble();
    36                 double y = lineScanner.nextDouble();
    37                 sr.addData(x, y);
    38             }
    39         } catch (FileNotFoundException e) {
    40             System.err.println(e);
    41         }
    42         return sr;
    43     }
    44 }
    View Code
    • 实现2:直接计算统计量
     1 import java.io.File;
     2 import java.io.FileNotFoundException;
     3 import java.util.Scanner;
     4 
     5 public class Example2 {
     6     private static double sX=0, sXX=0, sY=0, sYY=0, sXY=0;
     7     private static int n=0;
     8 
     9     public static void main(String[] args) {
    10         getData("data/Data1.dat");
    11         double m = (n*sXY - sX*sY)/(n*sXX - sX*sX);
    12         double b = sY/n - m*sX/n;
    13         double r2 = m*m*(n*sXX - sX*sX)/(n*sYY - sY*sY);
    14         double r = Math.sqrt(r2);
    15         double tv = sYY - sY*sY/n;
    16         double mX = sX/n;  // mean value of x
    17         double ev = (sXX - 2*mX*sX + n*mX*mX)*m*m;
    18         double uv = tv - ev;
    19         
    20         System.out.printf("y = %.6fx + %.4f%n", m, b);
    21         System.out.printf("r = %.6f%n", r);
    22         System.out.printf("r2 = %.6f%n", r2);
    23         System.out.printf("EV = %.5f%n", ev);
    24         System.out.printf("UV = %.4f%n", uv);
    25         System.out.printf("TV = %.3f%n", tv);
    26     }
    27     
    28     public static void getData(String data) {
    29         try {
    30             Scanner fileScanner = new Scanner(new File(data));
    31             fileScanner.nextLine();  // read past title line
    32             n = fileScanner.nextInt();
    33             fileScanner.nextLine();  // read past line of labels
    34             fileScanner.nextLine();  // read past line of labels
    35             for (int i = 0; i < n; i++) {
    36                 String line = fileScanner.nextLine();
    37                 Scanner lineScanner = new Scanner(line).useDelimiter("\t");
    38                 double x = lineScanner.nextDouble();
    39                 double y = lineScanner.nextDouble();
    40                 sX += x;
    41                 sXX += x*x;
    42                 sY += y;
    43                 sYY += y*y;
    44                 sXY += x*y;
    45             }
    46         } catch (FileNotFoundException e) {
    47             System.err.println(e);
    48         }
    49     }
    50 }
    View Code

    y = 0.882279x + 18.8739
    r = 0.935222
    r2 = 0.874641
    EV = 1423.35676
    UV = 204.0042
    TV = 1627.361

    • 实现3:对辅助类进行实例化,并绘图

    Example3.java

     1 import java.io.File;
     2 import javax.swing.JFrame;
     3 
     4 public class Example3 {
     5     public static void main(String[] args) {
     6         Data data = new Data(new File("data/Data1.dat"));
     7         JFrame frame = new JFrame(data.getTitle());
     8         frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
     9         RegressionPanel panel = new RegressionPanel(data);
    10         frame.add(panel);
    11         frame.pack();
    12         frame.setSize(500, 422);
    13         frame.setResizable(false);
    14         frame.setLocationRelativeTo(null);  // center frame on screen
    15         frame.setVisible(true);
    16     }
    17 }
    View Code

    Data.java

      1 import java.io.File;
      2 import java.io.FileNotFoundException;
      3 import java.util.Scanner;
      4 
      5 public class Data {
      6     private String title,xName, yName;
      7     private int n;
      8     private double[] x, y;
      9     private double sX, sXX, sY, sYY, sXY, minX, minY, maxX, maxY;
     10     private double meanX, meanY, slope, intercept, corrCoef;
     11 
     12     public Data(File inputFile) {
     13         try {
     14             Scanner input = new Scanner(inputFile);
     15             title = input.nextLine();
     16             n = input.nextInt();
     17             xName = input.next();
     18             yName = input.next();
     19             input.nextLine();
     20             x = new double[n];
     21             y = new double[n];
     22             minX = minY = Double.POSITIVE_INFINITY;
     23             maxX = maxY = Double.NEGATIVE_INFINITY;
     24             for (int i = 0; i < n; i++) {
     25                 double xi = x[i] = input.nextDouble();
     26                 double yi = y[i] = input.nextDouble();
     27                 sX += xi;
     28                 sXX += xi*xi;
     29                 sY += yi;
     30                 sYY += yi*yi;
     31                 sXY += xi*yi;
     32                 minX = (xi < minX? xi: minX);
     33                 minY = (yi < minY? yi: minY);
     34                 maxX = (xi > maxX? xi: maxX);
     35                 maxY = (yi > maxY? yi: maxY);
     36             }
     37             meanX = sX/n;
     38             meanY = sY/n;
     39             slope = (n*sXY - sX*sY)/(n*sXX - sX*sX);
     40             intercept = meanY - slope*meanX;
     41             corrCoef = slope*Math.sqrt((n*sXX - sX*sX)/(n*sYY - sY*sY));
     42         } catch (FileNotFoundException e) {
     43             System.err.println(e);
     44         }
     45     }
     46 
     47     public String getTitle() {
     48         return title;
     49     }
     50 
     51     public String getXName() {
     52         return xName;
     53     }
     54 
     55     public String getYName() {
     56         return yName;
     57     }
     58 
     59     public int getN() {
     60         return n;
     61     }
     62 
     63     public double[] getX() {
     64         return x;
     65     }
     66 
     67     public double[] getY() {
     68         return y;
     69     }
     70 
     71     public double getMeanX() {
     72         return meanX;
     73     }
     74 
     75     public double getMeanY() {
     76         return meanY;
     77     }
     78 
     79     public double getSlope() {
     80         return slope;
     81     }
     82 
     83     public double getIntercept() {
     84         return intercept;
     85     }
     86 
     87     public double getCorrCoef() {
     88         return corrCoef;
     89     }
     90     
     91     public double[][] getTable() {
     92         double[][] table = new double[n][2];
     93         for (int i = 0; i < n; i++) {
     94             table[i][0] = x[i];
     95             table[i][1] = y[i];
     96         }
     97         return table;
     98     }
     99 
    100     public double getMinX() {
    101         return minX;
    102     }
    103 
    104     public double getMinY() {
    105         return minY;
    106     }
    107 
    108     public double getMaxX() {
    109         return maxX;
    110     }
    111 
    112     public double getMaxY() {
    113         return maxY;
    114     }
    115 }
    View Code

    RegressionPanel.java

    import java.awt.BasicStroke;
    import java.awt.Color;
    import java.awt.Graphics;
    import java.awt.Graphics2D;
    import javax.swing.JPanel;
    
    /**
     * Swing panel that plots a {@link Data} set as a scatter diagram over a
     * light-grey grid and overlays its least-squares regression line in blue.
     */
    public class RegressionPanel extends JPanel {
        // Panel size in pixels; BUFFER is the left strip where y labels are drawn,
        // MARGIN the padding around the plotting area.
        private static final int WIDTH=500, HEIGHT=400, BUFFER=28, MARGIN=40;
        private final Data data;
        private double xMin, xMax, yMin, yMax, xRange, yRange, gWidth, gHeight;
        private double slope, intercept;

        /**
         * Creates a panel for the given data set, caching its extents and regression
         * coefficients for the coordinate transforms.
         *
         * @param data the data set to plot
         */
        public RegressionPanel(Data data) {
            this.data = data;
            this.setSize(WIDTH, HEIGHT);
            this.xMin = data.getMinX();
            this.xMax = data.getMaxX();
            this.yMin = data.getMinY();
            this.yMax = data.getMaxY();
            this.slope = data.getSlope();
            this.intercept = data.getIntercept();
            this.xRange = xMax - xMin;
            this.yRange = yMax - yMin;
            this.gWidth = WIDTH - 2*MARGIN - BUFFER;
            this.gHeight = HEIGHT - 2*MARGIN - BUFFER;
            setBackground(Color.WHITE);
        }

        @Override
        public void paintComponent(Graphics g) {
            super.paintComponent(g);
            Graphics2D g2 = (Graphics2D)g;
            g2.setStroke(new BasicStroke(1));
            drawGrid(g2);
            drawPoints(g2, data.getX(), data.getY());
            drawLine(g2);
        }

        /** Draws the vertical and horizontal grid lines with integer axis labels. */
        private void drawGrid(Graphics2D g2) {
            g2.setStroke(new BasicStroke(1));
            // Spacing: the largest power of ten not exceeding the x range.
            double xGd = Math.pow(10, Math.floor(Math.log10(xRange)));
            // Labels are ints, so step by at least 1; a 0 step (range < 1 or a
            // degenerate range) would otherwise make this loop never terminate.
            int xd = Math.max(1, dToI(xGd));
            int x0 = dToI(xGd*Math.floor(xMin/xGd));
            int xn = dToI(xGd*Math.ceil(xMax/xGd));
            for (int xi = x0; xi <= xn; xi += xd) {
                g2.setColor(Color.LIGHT_GRAY);
                int p = f(xi);
                g2.drawLine(p, 0, p, HEIGHT-18);  // vertical lines
                g2.setColor(Color.BLACK);
                g2.drawString(""+xi, p-8, HEIGHT-4);
            }
            double yGd = Math.pow(10, Math.floor(Math.log10(yRange)));
            int yd = Math.max(1, dToI(yGd));
            // The y bounds come from the y spacing and the y extent.
            int y0 = dToI(yGd*Math.floor(yMin/yGd));
            int yn = dToI(yGd*Math.ceil(yMax/yGd));
            for (int yi = y0; yi <= yn; yi += yd) {
                g2.setColor(Color.LIGHT_GRAY);
                int q = g(yi);
                g2.drawLine(BUFFER, q, WIDTH, q);  // horizontal lines
                g2.setColor(Color.BLACK);
                // Indent labels below 100 so they roughly line up with 3-digit ones.
                g2.drawString((yi<100?"  ":"")+yi, 2, q+5);
            }
        }

        /** Draws each (x[i], y[i]) pair as a filled 6px dot. */
        private void drawPoints(Graphics2D g2, double[] x, double[] y) {
            g2.setColor(Color.BLACK);
            for (int i = 0; i < x.length; i++) {
                int u = f(x[i]);
                int v = g(y[i]);
                g2.fillOval(u-3, v-3, 6, 6);  // coordinates are at NW corners
            }
        }

        /** Draws the regression line across the panel from pixel column BUFFER to WIDTH. */
        private void drawLine(Graphics2D g2) {
            g2.setColor(Color.BLUE);
            g2.setStroke(new BasicStroke(2));
            int p0 = BUFFER;
            int q0 = g(yLine(fInv(p0)));
            int p1 = WIDTH;
            int q1 = g(yLine(fInv(p1)));
            g2.drawLine(p0, q0, p1, q1);
        }

        /** Evaluates the regression line at data coordinate x. */
        private double yLine(double x) {
            return slope*x + intercept;
        }

        /** Rounds a double to the nearest int. */
        private int dToI(double x) {
            return (int)Math.round(x);
        }

        /** Maps a data x value to a pixel column. */
        private int f(double x) {
            return dToI((x - xMin)*gWidth/xRange) + BUFFER + MARGIN;
        }

        /** Maps a data y value to a pixel row (rows grow downward, so y is flipped). */
        private int g(double y) {
            return dToI(gHeight - (y - yMin)*gHeight/yRange) + MARGIN;
        }

        /** Inverse of f (ignoring rounding): maps a pixel column back to a data x value. */
        private double fInv(int p) {
            return (p - BUFFER - MARGIN)*xRange/gWidth + xMin;
        }

        /** Inverse of g (ignoring rounding); not used within this class. */
        private double gInv(int q) {
            return yMin + (gHeight + MARGIN - q)*yRange/gHeight;
        }
    }
    View Code

    多项式回归

    • 需求:已知刹车速度与刹车距离的数据,用二次多项式拟合刹车距离关于速度的函数,并预测速度为55时的刹车距离
    • 实现:最小二乘法,解方程组,LU分解
     1 import org.apache.commons.math3.linear.*;
     2 
     3 public class Example4 {
     4     static double[] x = {20, 30, 40, 50, 60, 70};
     5     static double[] y = {52, 87, 136, 203, 290, 394};
     6     static int n = y.length;  // 6
     7 
     8     public static void main(String[] args) {
     9         double[][] a = new double[3][3];
    10         double[] w = new double[3];
    11         deriveNormalEquations(a, w);
    12         printNormalEquations(a, w);
    13         double[] b = solveNormalEquations(a, w);
    14         printResults(b);
    15     }
    16 
    17     public static void deriveNormalEquations(double[][] a, double[] w) {
    18         for (int i = 0; i < n; i++) {
    19             double xi = x[i];
    20             double yi = y[i];
    21             a[0][0] = n;
    22             a[0][1] = a[1][0] += xi;
    23             a[0][2] = a[1][1] = a[2][0] += xi*xi;
    24             a[1][2] = a[2][1] += xi*xi*xi;
    25             a[2][2] += xi*xi*xi*xi;
    26             w[0] += yi;
    27             w[1] += xi*yi;
    28             w[2] += xi*xi*yi;
    29         }
    30     }
    31 
    32     public static void printNormalEquations(double[][] a, double[] w) {
    33         for (int i = 0; i < 3; i++) {
    34             System.out.printf("%8.0fb0 + %6.0fb1 + %8.0fb2 = %7.0f%n",
    35                     a[i][0], a[i][1], a[i][2], w[i]);
    36         }
    37     }
    38 
    39     /*  Solves the matrix equation a*b = w for b[], representing a[] 
    40         as RealMatrix m and b[] as RealVector v: 
    41      */
    42     private static double[] solveNormalEquations(double[][] a, double[] w) {
    43             RealMatrix m = new Array2DRowRealMatrix(a, false);
    44             LUDecomposition lud = new LUDecomposition(m);
    45             DecompositionSolver solver = lud.getSolver();
    46             RealVector v = new ArrayRealVector(w, false);
    47             return solver.solve(v).toArray();
    48     }
    49     
    50     private static void printResults(double[] b) {
    51         System.out.printf("f(t) = %.2f + %.3ft + %.5ft^2%n", b[0], b[1], b[2]);
    52         System.out.printf("f(55) = %.1f%n", f(55, b));
    53     }
    54     
    55     private static double f(double t, double[] b) {
    56         return b[0] + b[1]*t + b[2]*t*t;
    57     }
    58 }
    View Code

    6b0 + 270b1 + 13900b2 = 1162
    270b0 + 13900b1 + 783000b2 = 64220
    13900b0 + 783000b1 + 46750000b2 = 3798800
    f(t) = 40.73 + -1.170t + 0.08875t^2
    f(55) = 244.8

    多元线性回归

    • 需求:因变量依赖于多个自变量(示例中z依赖于x和y)
    • 实现:直接构造并求解正规方程组(Example5),或使用Apache Commons Math的OLSMultipleLinearRegression(Example6)

    Example5.java

     1 import org.apache.commons.math3.linear.*;
     2 
     3 public class Example5 {
     4     static double[] x = {10, 9, 12, 10, 9, 10, 8, 11};
     5     static double[] y = {59, 57, 61, 52, 48, 55, 51, 62};
     6     static double[] z = {71, 68, 76, 56, 57, 77, 55, 67};
     7     static int n = z.length;  // 8
     8             
     9     public static void main(String[] args) {
    10         double[][] a = new double[3][3];
    11         double[] w = new double[3];
    12         deriveNormalEquations(a, w);
    13         printNormalEquations(a, w);
    14         double[] b = solveNormalEquations(a, w);
    15         printResults(b);
    16     }
    17 
    18     public static void deriveNormalEquations(double[][] a, double[] w) {
    19         for (int i = 0; i < n; i++) {
    20             double xi = x[i];
    21             double yi = y[i];
    22             double zi = z[i];
    23             a[0][0] = n;
    24             a[0][1] = a[1][0] += xi;
    25             a[0][2] = a[2][0] += yi;
    26             a[1][1] += xi*xi;
    27             a[1][2] = a[2][1] += xi*yi;
    28             a[2][2] += yi*yi;
    29             w[0] += zi;
    30             w[1] += xi*zi;
    31             w[2] += yi*zi;
    32         }
    33     }
    34 
    35     public static void printNormalEquations(double[][] a, double[] w) {
    36         for (int i = 0; i < 3; i++) {
    37             System.out.printf("%6.0fx0 + %4.0fx1 + %5.0fx2 = %5.0f%n",
    38                     a[i][0], a[i][1], a[i][2], w[i]);
    39         }
    40     }
    41 
    42     private static double[] solveNormalEquations(double[][] a, double[] w) {
    43         RealMatrix m = new Array2DRowRealMatrix(a, false);
    44         LUDecomposition lud = new LUDecomposition(m);
    45         DecompositionSolver solver = lud.getSolver();
    46         RealVector v = new ArrayRealVector(w, false);
    47         return solver.solve(v).toArray();
    48     }
    49     
    50     private static void printResults(double[] b) {
    51         System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]);
    52         System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b));
    53         System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b));
    54         System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b));
    55     }
    56     
    57     private static double f(double s, double t, double[] b) {
    58         return b[0] + b[1]*s + b[2]*t;
    59     }
    60 }
    View Code

    Example6.java

     1 import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
     2 
     3 public class Example6 {
     4     static double[][] x = { {10, 59}, {9, 57}, {12, 61}, {10, 52}, {9, 48}, 
     5             {10, 55}, {8, 51}, {11, 62} };
     6     static double[] y = {71, 68, 76, 56, 57, 77, 55, 67};
     7 
     8     public static void main(String[] args) {
     9         OLSMultipleLinearRegression mlr = new OLSMultipleLinearRegression();
    10         mlr.newSampleData(y, x);
    11         double[] b = mlr.estimateRegressionParameters();
    12         printResults(b);
    13     }
    14     
    15     private static void printResults(double[] b) {
    16         System.out.printf("f(s, t) = %.2f + %.2fs + %.2ft%n", b[0], b[1], b[2]);
    17         System.out.printf("f(10, 59) = %.1f%n", f(10, 59, b));
    18         System.out.printf("f(9, 57) = %.1f%n", f(9, 57, b));
    19         System.out.printf("f(11, 64) = %.1f%n", f(11, 64, b));
    20     }
    21     
    22     private static double f(double s, double t, double[] b) {
    23         return b[0] + b[1]*s + b[2]*t;
    24     }
    25 }
    View Code

    8x0 + 79x1 + 445x2 = 527
    79x0 + 791x1 + 4427x2 = 5254
    445x0 + 4427x1 + 24929x2 = 29543
    f(s, t) = -5.75 + 1.55s + 1.01t
    f(10, 59) = 69.5
    f(9, 57) = 65.9
    f(11, 64) = 76.1

  • 相关阅读:
    $ [Contest #4]$求和 思博题
    洛谷$P1864 [NOI2009]$二叉查找树 区间$dp$
    洛谷$P4045 [JSOI2009]$密码 $dp$+$AC$自动机
    $bzoj2560$ 串珠子 容斥+$dp$
    洛谷$P1600$ 天天爱跑步 树上差分
    $loj526 [LibreOJ eta Round #4]$ 子集 图论
    $CF888G Xor-MST$ 最小生成树
    $bzoj4152 The Captain$ 最短路
    洛谷$P3645 [APIO2015]$雅加达的摩天楼 最短路
    $bzoj4722$ 由乃 搜索
  • 原文地址:https://www.cnblogs.com/cxc1357/p/14687794.html
Copyright © 2011-2022 走看看