zoukankan      html  css  js  c++  java
  • 线性回归与逻辑回归的 Java 实现

      线性回归和逻辑回归的实现大体一致，因此将公共部分抽象为抽象类 Regression，负责保存样本与参数并组织整体流程；其中的三个抽象函数 PreVal（预测）、CostFun（代价函数）和 Update（梯度下降更新参数）分别在线性回归和逻辑回归中实现。

      样本用 Sample 类表示，特征以 double 数组的形式存储。

    1. 样本类Sample

     /**
      * A single training example: a feature vector plus its regression target and class label.
      */
     public class Sample {

         double[] features; // feature values; in the bundled data sets features[0] is the bias x0 = 1
         int feaNum; // number of features (length of features)
         double value; // target value, read by LinearRegression
         int label; // class label (0 or 1 in the bundled data), read by LogisticRegression

         /**
          * Creates a sample with {@code number} zero-initialised features.
          *
          * @param number number of features
          */
         public Sample(int number) {
             feaNum = number;
             features = new double[feaNum];
         }

         /** Prints the features, label and value of this sample to standard output. */
         public void outSample() {
             System.out.println("The sample's features are:");
             for (int i = 0; i < feaNum; i++) {
                 System.out.print(features[i] + " ");
             }
             System.out.println();
             System.out.println("The label is: " + label);
             System.out.println("The value is: " + value);
         }
     }

    2. 抽象类Regression

    /**
     * Common skeleton for regression models trained by gradient descent.
     *
     * <p>This class stores the training set and hyper-parameters. Subclasses supply the
     * hypothesis ({@link #PreVal}), the cost function ({@link #CostFun}) and the training loop
     * ({@link #Update}).
     */
    public abstract class Regression {

        double[] theta; // model parameters, one per feature
        int paraNum; // number of parameters (length of theta)
        double rate; // learning rate
        Sample[] sam; // training samples
        int samNum; // number of training samples
        double th; // threshold; the subclasses in this project stop when the cost drop is <= th

        /**
         * Stores the first {@code num} samples of {@code s} as the training set.
         *
         * @param s training set
         * @param num number of samples to take from the front of {@code s}
         * @throws IndexOutOfBoundsException if {@code num > s.length}
         */
        public void Initialize(Sample[] s, int num) {
            samNum = num;
            sam = new Sample[samNum];
            System.arraycopy(s, 0, sam, 0, samNum);
        }

        /**
         * Sets the initial parameters and the training hyper-parameters.
         *
         * @param para initial values of theta; copied, so the caller's array is never modified
         * @param learning_rate gradient-descent step size
         * @param threshold stopping threshold on the per-iteration cost decrease
         */
        public void setPara(double[] para, double learning_rate, double threshold) {
            paraNum = para.length;
            theta = para.clone(); // defensive copy: training must not write into the caller's array
            rate = learning_rate;
            th = threshold;
        }

        /**
         * Predicts the output for a sample using the current parameters.
         *
         * @param s sample to predict
         * @return predicted value
         */
        public abstract double PreVal(Sample s);

        /**
         * Computes the cost of the current parameters over all training samples.
         *
         * @return the cost
         */
        public abstract double CostFun();

        /** Trains the model, updating {@code theta}. */
        public abstract void Update();

        /**
         * Prints the parameters on one line and the final cost on its own labelled line.
         * Previously the cost was appended to the parameter line, so it looked like an extra
         * parameter.
         */
        public void OutputTheta() {
            System.out.println("The parameters are:");
            for (int i = 0; i < paraNum; i++) {
                System.out.print(theta[i] + " ");
            }
            System.out.println();
            System.out.println("The cost is: " + CostFun());
        }
    }

    3. 线性回归LinearRegression

    public class LinearRegression extends Regression{
    
        public double PreVal(Sample s) {
            double val = 0;
            for(int i = 0; i < paraNum; i++) {
                val += theta[i] * s.features[i];
            }
            return val;
        }
        
        public double CostFun() {
            double sum = 0;
            for(int i = 0; i < samNum; i++) {
                double d = PreVal(sam[i]) - sam[i].value;
                sum += Math.pow(d, 2);
            }
            return sum / (2*samNum);
        }
        
        public void Update() {
             double former = 0; // the cost before update
             double latter = CostFun(); // the cost after updatedouble[] p = new double[paraNum];
             do {
                 former = latter;
                 //update theta
                 for(int i = 0; i < paraNum; i++) {
                     // for theta[i]
    double d = 0; for(int j = 0; j < samNum; j++) { d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i]; } p[i] -= (rate * d) / samNum; } theta = p; latter = CostFun();

             if(former - latter < 0){
              System.out.println("α is larger!!!");
              break;
            }

           }while(former - latter > th);

        }
    
    }

    4. 逻辑回归LogisticRegression

    public class LogisticRegression extends Regression{
    
        public double PreVal(Sample s) {
            double val = 0;
            for(int i = 0; i < paraNum; i++) {
                val += theta[i] * s.features[i];
            }
            return 1/(1 + Math.pow(Math.E, -val));
        }
    
        public double CostFun() {
            double sum = 0;
            for(int i = 0; i < samNum; i++) {
                double p = PreVal(sam[i]);
                double d = Math.log(p) * sam[i].label + (1 - sam[i].label) * Math.log(1 - p);
                sum += d;
            }
            return -1 * (sum / samNum);
        }
        
        public void Update() {
             double former = 0; // the cost before update
             double latter = CostFun(); // the cost after update
             double d = 0;
             double[] p = new double[paraNum];
             do {
                 former = latter;
                 //update theta
                 for(int i = 0; i < paraNum; i++) {
                     // for theta[i]
    double d = 0;
    for(int j = 0; j < samNum; j++) { d += (PreVal(sam[j]) - sam[j].value) * sam[j].features[i]; } p[i] -= (rate * d) / samNum; } latter = CostFun();

             if(former - latter < 0){
              System.out.println("α is larger!!!");
              break;
             }

          }while(former - latter > th);

             theta = p;
        }
    }

    5. 使用的线性回归样本（文件 resource/LinearSample.txt。下表第一行表头仅作说明，实际文件中不含表头；每行依次为 x0～x4 和目标值 y，以空格分隔，x0 恒为 1，作为偏置项）

    x0 x1 x2 x3 x4 y
    1 2104 5 1 45 460
    1 1416 3 2 40 232
    1 1534 3 2 30 315
    1 852 2 1 36 178
    1 1254 3 3 45 321
    1 987 2 2 35 241
    1 1054 3 2 30 287
    1 645 2 3 25 87
    1 542 2 1 30 94
    1 1065 3 1 25 241
    1 2465 7 2 50 687
    1 2410 6 1 45 654
    1 1987 4 2 45 436
    1 457 2 3 35 65
    1 587 2 2 25 54
    1 468 2 1 40 87
    1 1354 3 1 35 215
    1 1587 4 1 45 345
    1 1789 4 2 35 325
    1 2500 8 2 40 720

    6. 线性回归测试

    import java.io.IOException;
    import java.io.RandomAccessFile;
    
    /** Demo: trains a {@link LinearRegression} on resource/LinearSample.txt. */
    public class Test {

        /**
         * Reads the training file, runs gradient descent and prints the learned parameters.
         *
         * <p>Each non-blank line holds whitespace-separated feature values followed by the
         * target value.
         *
         * @param args unused
         * @throws IOException if the sample file cannot be read
         */
        public static void main(String[] args) throws IOException {
            Sample[] sam = new Sample[25]; // initial capacity; grown on demand
            int w = 0; // number of samples read

            // try-with-resources: the file was previously never closed.
            try (RandomAccessFile file = new RandomAccessFile("resource//LinearSample.txt", "r")) {
                String s;
                while ((s = file.readLine()) != null) {
                    s = s.trim();
                    if (s.isEmpty()) {
                        continue; // tolerate blank lines instead of failing in parseDouble
                    }
                    String[] sub = s.split("\\s+");
                    if (w == sam.length) {
                        // Grow rather than overflow the fixed-size buffer.
                        Sample[] bigger = new Sample[sam.length * 2];
                        System.arraycopy(sam, 0, bigger, 0, w);
                        sam = bigger;
                    }
                    Sample sample = new Sample(sub.length - 1);
                    for (int i = 0; i < sub.length - 1; i++) {
                        sample.features[i] = Double.parseDouble(sub[i]);
                    }
                    sample.value = Double.parseDouble(sub[sub.length - 1]);
                    sam[w++] = sample;
                }
            }

            LinearRegression lr = new LinearRegression();
            double[] para = {0, 0, 0, 0, 0};
            // NOTE(review): the features are unscaled (x1 is in the thousands), so a rate this large
            // is expected to raise the cost on the first step. Consider normalising the features
            // or using a much smaller rate.
            double rate = 0.5;
            double th = 0.001;
            lr.Initialize(sam, w);
            lr.setPara(para, rate, th);
            lr.Update();
            lr.OutputTheta();
        }
    }

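    7. 使用的逻辑回归样本（文件 resource/LogisticSample.txt。下表第一行表头仅作说明，实际文件中不含表头；每行依次为 x0～x2 和类别 class（0 或 1），以空格分隔，x0 恒为 1，作为偏置项）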

    x0 x1 x2 class
    1 0.23 0.35 0
    1 0.32 0.24 0
    1 0.6 0.12 0
    1 0.36 0.54 0
    1 0.02 0.89 0
    1 0.36 -0.12 0
    1 -0.45 0.62 0
    1 0.56 0.42 0
    1 0.4 0.56 0
    1 0.46 0.51 0
    1 1.2 0.32 1
    1 0.6 0.9 1
    1 0.32 0.98 1
    1 0.2 1.3 1
    1 0.15 1.36 1
    1 0.54 0.98 1
    1 1.36 1.05 1
    1 0.22 1.65 1
    1 1.65 1.54 1
    1 0.25 1.68 1

    8. 逻辑回归测试

    import java.io.IOException;
    import java.io.RandomAccessFile;
    
    /** Demo: trains a {@link LogisticRegression} on resource/LogisticSample.txt. */
    public class Test {

        /**
         * Reads the training file, runs gradient descent and prints the learned parameters.
         *
         * <p>Each non-blank line holds whitespace-separated feature values followed by an
         * integer class label.
         *
         * @param args unused
         * @throws IOException if the sample file cannot be read
         */
        public static void main(String[] args) throws IOException {
            Sample[] sam = new Sample[25]; // initial capacity; grown on demand
            int w = 0; // number of samples read

            // try-with-resources: the file was previously never closed.
            try (RandomAccessFile file = new RandomAccessFile("resource//LogisticSample.txt", "r")) {
                String s;
                while ((s = file.readLine()) != null) {
                    s = s.trim();
                    if (s.isEmpty()) {
                        continue; // tolerate blank lines instead of failing in the parsers
                    }
                    String[] sub = s.split("\\s+");
                    if (w == sam.length) {
                        // Grow rather than overflow the fixed-size buffer.
                        Sample[] bigger = new Sample[sam.length * 2];
                        System.arraycopy(sam, 0, bigger, 0, w);
                        sam = bigger;
                    }
                    Sample sample = new Sample(sub.length - 1);
                    for (int i = 0; i < sub.length - 1; i++) {
                        sample.features[i] = Double.parseDouble(sub[i]);
                    }
                    sample.label = Integer.parseInt(sub[sub.length - 1]);
                    sam[w++] = sample;
                }
            }

            LogisticRegression lr = new LogisticRegression();
            double[] para = {0, 0, 0};
            double rate = 0.5;
            double th = 0.001;
            lr.Initialize(sam, w);
            lr.setPara(para, rate, th);
            lr.Update();
            lr.OutputTheta();
        }
    }
  • 相关阅读:
    Winform 时间
    button的后台点击事件
    Winform文本框只能输入限定的文本
    vue的生命周期函数
    ES6新增语法
    购物车案例(JavaScript动态效果)
    前端es6总结
    jQuery与vue的区别是什么?
    vue实现双向绑定原理
    JS实现简单分页功能
  • 原文地址:https://www.cnblogs.com/datamining-bio/p/9240378.html
Copyright © 2011-2022 走看看