zoukankan      html  css  js  c++  java
  • 线性回归与逻辑回归的 Java 实现

      线性回归和逻辑回归的实现大体一致，因此把公共流程抽象为抽象类 Regression。其中包含三个抽象方法（预测函数 PreVal、代价函数 CostFun、参数更新 Update），分别在线性回归和逻辑回归的子类中重写。

      样本用 Sample 类表示，其特征以 double 数组的形式存储。

    1. 样本类Sample

     /**
      * One training example: a feature vector plus a regression target and a class label.
      * Linear regression reads {@code features} and {@code value}; logistic regression reads
      * {@code features} and {@code label}.
      */
     public class Sample {

         double[] features; // feature vector, length feaNum
         int feaNum; // the number of sample's features
         double value; // target value of the sample in (linear) regression
         int label; // class of the sample (0 or 1 in the logistic example data)

         /**
          * Creates a sample with {@code number} features, all initialised to 0.
          *
          * @param number the number of features
          * @throws NegativeArraySizeException if {@code number} is negative
          */
         public Sample(int number) {
             feaNum = number;
             features = new double[feaNum];
         }

         /** Prints the features, label and value of this sample to standard output. */
         public void outSample() {
             System.out.println("The sample's features are:");
             for (int i = 0; i < feaNum; i++) {
                 System.out.print(features[i] + " ");
             }
             System.out.println();
             System.out.println("The label is: " + label);
             System.out.println("The value is: " + value);
         }
     }

    2. 抽象类Regression

    /**
     * Shared state and helpers for regression models trained by gradient descent.
     *
     * <p>Typical use: {@link #Initialize} with the training set, {@link #setPara} with the
     * initial parameters and hyper-parameters, then {@link #Update} to train and
     * {@link #OutputTheta} to print the result. Subclasses supply the hypothesis
     * ({@link #PreVal}), the cost ({@link #CostFun}) and the training loop ({@link #Update}).
     */
    public abstract class Regression {

        double[] theta; // parameters; theta[i] is paired with features[i]
        int paraNum; // the number of parameters (theta.length)
        double rate; // learning rate
        Sample[] sam; // training samples
        int samNum; // the number of training samples
        double th; // stopping threshold, interpreted by the subclass's Update()

        /**
         * Stores the first {@code num} entries of {@code s} as the training set.
         *
         * @param s training set
         * @param num the number of leading entries of {@code s} to use
         * @throws IllegalArgumentException if {@code s} is null or {@code num} is outside
         *     {@code [0, s.length]}
         */
        public void Initialize(Sample[] s, int num) {
            if (s == null) {
                throw new IllegalArgumentException("training set must not be null");
            }
            if (num < 0 || num > s.length) {
                throw new IllegalArgumentException(
                        "num must be in [0, " + s.length + "], got " + num);
            }
            samNum = num;
            sam = new Sample[samNum];
            System.arraycopy(s, 0, sam, 0, samNum);
        }

        /**
         * Sets the initial parameters and the hyper-parameters.
         *
         * @param para initial theta (must not be null). It is copied so that training never
         *     writes into the caller's array. Subclass predictions index {@code features[i]} for
         *     every {@code i < para.length}, so each sample needs at least that many features.
         * @param learning_rate step size of gradient descent
         * @param threshold stopping threshold passed to the training loop
         */
        public void setPara(double[] para, double learning_rate, double threshold) {
            theta = para.clone(); // defensive copy: avoid aliasing the caller's array
            paraNum = theta.length;
            rate = learning_rate;
            th = threshold;
        }

        /**
         * Predicts the output for sample {@code s} with the current parameters.
         *
         * @param s sample to predict
         * @return predicted value
         */
        public abstract double PreVal(Sample s);

        /**
         * Computes the cost of the current parameters over all training samples.
         *
         * @return the cost
         */
        public abstract double CostFun();

        /** Trains the model by updating {@code theta}. */
        public abstract void Update();

        /** Prints the current parameters on one line, then the cost on the training set. */
        public void OutputTheta() {
            System.out.println("The parameters are:");
            for (int i = 0; i < paraNum; i++) {
                System.out.print(theta[i] + " ");
            }
            System.out.println();
            System.out.println("The cost is: " + CostFun());
        }
    }

    3. 线性回归LinearRegression

    public class LinearRegression extends Regression{
    
        public double PreVal(Sample s) {
            double val = 0;
            for(int i = 0; i < paraNum; i++) {
                val += theta[i] * s.features[i];
            }
            return val;
        }
        
        public double CostFun() {
            double sum = 0;
            for(int i = 0; i < samNum; i++) {
                double d = PreVal(sam[i]) - sam[i].value;
                sum += Math.pow(d, 2);
            }
            return sum / (2*samNum);
        }
        
        public void Update() {
             double former = 0; // the cost before update
             double latter = CostFun(); // the cost after updatedouble[] p = new double[paraNum];
             do {
                 former = latter;
                 //update theta
                 for(int i = 0; i < paraNum; i++) {
                     // for theta[i]
    double d = 0; for(int j = 0; j < samNum; j++) { d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i]; } p[i] -= (rate * d) / samNum; } theta = p; latter = CostFun();

             if(former - latter < 0){
              System.out.println("α is larger!!!");
              break;
            }

           }while(former - latter > th);

        }
    
    }

    4. 逻辑回归LogisticRegression

    public class LogisticRegression extends Regression{
    
        public double PreVal(Sample s) {
            double val = 0;
            for(int i = 0; i < paraNum; i++) {
                val += theta[i] * s.features[i];
            }
            return 1/(1 + Math.pow(Math.E, -val));
        }
    
        public double CostFun() {
            double sum = 0;
            for(int i = 0; i < samNum; i++) {
                double p = PreVal(sam[i]);
                double d = Math.log(p) * sam[i].label + (1 - sam[i].label) * Math.log(1 - p);
                sum += d;
            }
            return -1 * (sum / samNum);
        }
        
        public void Update() {
             double former = 0; // the cost before update
             double latter = CostFun(); // the cost after update
             double d = 0;
             double[] p = new double[paraNum];
             do {
                 former = latter;
                 //update theta
                 for(int i = 0; i < paraNum; i++) {
                     // for theta[i]
    double d = 0;
    for(int j = 0; j < samNum; j++) { d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i]; } p[i] -= (rate * d) / samNum; } latter = CostFun();

             if(former - latter < 0){
              System.out.println("α is larger!!!");
              break;
             }

          }while(former - latter > th);

             theta = p;
        }
    }

    5. 线性回归使用的样本（每行依次为特征 x0～x4 和目标值 y，其中 x0 恒为 1，对应截距项；首行表头仅作说明，测试程序会把每一行都解析为数字，因此样本文件中不应包含表头）

    x0 x1 x2 x3 x4 y
    1 2104 5 1 45 460
    1 1416 3 2 40 232
    1 1534 3 2 30 315
    1 852 2 1 36 178
    1 1254 3 3 45 321
    1 987 2 2 35 241
    1 1054 3 2 30 287
    1 645 2 3 25 87
    1 542 2 1 30 94
    1 1065 3 1 25 241
    1 2465 7 2 50 687
    1 2410 6 1 45 654
    1 1987 4 2 45 436
    1 457 2 3 35 65
    1 587 2 2 25 54
    1 468 2 1 40 87
    1 1354 3 1 35 215
    1 1587 4 1 45 345
    1 1789 4 2 35 325
    1 2500 8 2 40 720

    6. 线性回归测试

    import java.io.IOException;
    import java.io.RandomAccessFile;
    import java.util.ArrayList;
    import java.util.List;

    /**
     * Trains LinearRegression on resource//LinearSample.txt and prints the fitted parameters.
     * Each non-blank line of the file holds the feature values followed by the target value,
     * separated by whitespace.
     */
    public class Test {

        public static void main(String[] args) throws IOException {
            Sample[] sam = readSamples("resource//LinearSample.txt");

            LinearRegression lr = new LinearRegression();
            double[] para = {0, 0, 0, 0, 0};
            // NOTE(review): x1 is in the thousands and features are not scaled, so a rate of 0.5
            // will most likely make the very first step increase the cost. Consider feature
            // scaling or a much smaller rate.
            double rate = 0.5;
            double th = 0.001;
            lr.Initialize(sam, sam.length);
            lr.setPara(para, rate, th);
            lr.Update();
            lr.OutputTheta();
        }

        /**
         * Reads one Sample per non-blank line. The last column is the regression target; the
         * other columns are features.
         *
         * @param path file to read
         * @return the parsed samples, in file order
         * @throws IOException if the file cannot be read
         * @throws NumberFormatException if a column is not a number
         */
        private static Sample[] readSamples(String path) throws IOException {
            List<Sample> samples = new ArrayList<>();
            try (RandomAccessFile file = new RandomAccessFile(path, "r")) {
                String line;
                while ((line = file.readLine()) != null) {
                    line = line.trim();
                    if (line.isEmpty()) {
                        continue; // tolerate blank / trailing lines
                    }
                    String[] sub = line.split("\\s+");
                    Sample sample = new Sample(sub.length - 1);
                    for (int i = 0; i < sub.length - 1; i++) {
                        sample.features[i] = Double.parseDouble(sub[i]);
                    }
                    sample.value = Double.parseDouble(sub[sub.length - 1]);
                    samples.add(sample);
                }
            }
            return samples.toArray(new Sample[0]);
        }
    }

    7. 逻辑回归使用的样本（每行依次为特征 x0～x2 和类别标签 class，其中 x0 恒为 1；首行表头仅作说明，样本文件中不应包含表头）

    x0 x1 x2 class
    1 0.23 0.35 0
    1 0.32 0.24 0
    1 0.6 0.12 0
    1 0.36 0.54 0
    1 0.02 0.89 0
    1 0.36 -0.12 0
    1 -0.45 0.62 0
    1 0.56 0.42 0
    1 0.4 0.56 0
    1 0.46 0.51 0
    1 1.2 0.32 1
    1 0.6 0.9 1
    1 0.32 0.98 1
    1 0.2 1.3 1
    1 0.15 1.36 1
    1 0.54 0.98 1
    1 1.36 1.05 1
    1 0.22 1.65 1
    1 1.65 1.54 1
    1 0.25 1.68 1

    8. 逻辑回归测试

    import java.io.IOException;
    import java.io.RandomAccessFile;
    import java.util.ArrayList;
    import java.util.List;

    /**
     * Trains LogisticRegression on resource//LogisticSample.txt and prints the fitted
     * parameters. Each non-blank line of the file holds the feature values followed by an
     * integer class label, separated by whitespace.
     */
    public class Test {

        public static void main(String[] args) throws IOException {
            Sample[] sam = readSamples("resource//LogisticSample.txt");

            LogisticRegression lr = new LogisticRegression();
            double[] para = {0, 0, 0};
            double rate = 0.5;
            double th = 0.001;
            lr.Initialize(sam, sam.length);
            lr.setPara(para, rate, th);
            lr.Update();
            lr.OutputTheta();
        }

        /**
         * Reads one Sample per non-blank line. The last column is the integer class label; the
         * other columns are features.
         *
         * @param path file to read
         * @return the parsed samples, in file order
         * @throws IOException if the file cannot be read
         * @throws NumberFormatException if a column cannot be parsed
         */
        private static Sample[] readSamples(String path) throws IOException {
            List<Sample> samples = new ArrayList<>();
            try (RandomAccessFile file = new RandomAccessFile(path, "r")) {
                String line;
                while ((line = file.readLine()) != null) {
                    line = line.trim();
                    if (line.isEmpty()) {
                        continue; // tolerate blank / trailing lines
                    }
                    String[] sub = line.split("\\s+");
                    Sample sample = new Sample(sub.length - 1);
                    for (int i = 0; i < sub.length - 1; i++) {
                        sample.features[i] = Double.parseDouble(sub[i]);
                    }
                    sample.label = Integer.parseInt(sub[sub.length - 1]);
                    samples.add(sample);
                }
            }
            return samples.toArray(new Sample[0]);
        }
    }
  • 相关阅读:
    668. Kth Smallest Number in Multiplication Table
    658. Find K Closest Elements
    483. Smallest Good Base
    475. Heaters
    454. 4Sum II
    441. Arranging Coins
    436. Find Right Interval
    410. Split Array Largest Sum
    392. Is Subsequence
    378. Kth Smallest Element in a Sorted Matrix
  • 原文地址:https://www.cnblogs.com/datamining-bio/p/9240378.html
Copyright © 2011-2022 走看看