  • Notes on Logistic Regression

    Logistic regression is essentially linear regression whose output is squashed into the (0, 1) interval by the logistic (sigmoid) function.

    For a binary classification problem, the input is a vector x = (x1, x2, ..., xn), Θ = (θ0, θ1, θ2, ..., θn) are the parameters learned by the algorithm, and the label is 0 or 1. Let

        z = θ0 + θ1·x1 + θ2·x2 + ... + θn·xn

    which, with the convention x0 = 1, can be written as

        z = Θᵀx

    We then map z through the sigmoid function

        g(z) = 1 / (1 + e^(−z))

    If g(z) is close to 0 the sample is assigned to class 0; otherwise it is assigned to class 1. What remains is to train the parameters Θ = (θ0, θ1, ..., θn), which we do with gradient descent. For each training pair (x, y) the update is

        θj := θj + α · (y − g(Θᵀx)) · xj

    where α is the learning rate.
    Logistic regression itself handles only two-class problems; softmax regression, which generalizes it, can be used for multi-class classification. In both cases the classes must be linearly separable.
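
    Before the multi-class code below, the binary case described by the formulas above can be sketched in a few lines of C++. This is a minimal illustration with made-up toy data, not the implementation used later:

    ```cpp
    #include <cmath>
    #include <cstdio>

    // sigmoid: g(z) = 1 / (1 + e^(-z))
    double sigmoid(double z) {
      return 1.0 / (1.0 + std::exp(-z));
    }

    int main() {
      // two 2-D samples with labels 0 and 1 (toy data for illustration)
      double X[2][2] = {{0.0, 1.0}, {1.0, 0.0}};
      int    Y[2]    = {0, 1};

      double theta[2] = {0.0, 0.0};  // weights
      double theta0   = 0.0;         // bias term (the x0 = 1 component)
      double alpha    = 0.5;         // learning rate

      // gradient descent: theta_j := theta_j + alpha * (y - g(z)) * x_j
      for (int epoch = 0; epoch < 1000; epoch++) {
        for (int i = 0; i < 2; i++) {
          double z   = theta0 + theta[0]*X[i][0] + theta[1]*X[i][1];
          double err = Y[i] - sigmoid(z);
          theta0 += alpha * err;
          for (int j = 0; j < 2; j++) theta[j] += alpha * err * X[i][j];
        }
      }

      // after training, sample 0 scores near 0 and sample 1 near 1
      for (int i = 0; i < 2; i++) {
        double z = theta0 + theta[0]*X[i][0] + theta[1]*X[i][1];
        printf("sample %d: g(z) = %f\n", i, sigmoid(z));
      }
      return 0;
    }
    ```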


    //LogisticRegression.h
    // Multi-class logistic (softmax) regression classifier.
    class LogisticRegression {
    
    public:
      int N;        // number of training samples
      int n_in;     // input dimension
      int n_out;    // number of classes
      double **W;   // weights, n_out x n_in
      double *b;    // biases, n_out
      LogisticRegression(int, int, int);
      ~LogisticRegression();
      void train(int*, int*, double);   // one gradient step on sample (x, y)
      void softmax(double*);            // in-place softmax over n_out scores
      void predict(int*, double*);      // class probabilities for input x
    };


    //LogisticRegression.cpp
    #include <iostream>
    #include <string>
    #include <cmath>
    #include "LogisticRegression.h"
    using namespace std;
    
    
    LogisticRegression::LogisticRegression(int size, int in, int out) {
      N = size;
      n_in = in;
      n_out = out;
    
      // initialize W, b
      W = new double*[n_out];
      for(int i=0; i<n_out; i++) W[i] = new double[n_in];
      b = new double[n_out];
    
      for(int i=0; i<n_out; i++) {
        for(int j=0; j<n_in; j++) {
          W[i][j] = 0;
        }
        b[i] = 0;
      }
    }
    
    LogisticRegression::~LogisticRegression() {
      for(int i=0; i<n_out; i++) delete[] W[i];
      delete[] W;
      delete[] b;
    }
    
    
    void LogisticRegression::train(int *x, int *y, double lr) {
      double *p_y_given_x = new double[n_out];
      double *dy = new double[n_out];
    
      // forward pass: scores = W*x + b, then p(y|x) = softmax(scores)
      for(int i=0; i<n_out; i++) {
        p_y_given_x[i] = 0;
        for(int j=0; j<n_in; j++) {
          p_y_given_x[i] += W[i][j] * x[j];
        }
        p_y_given_x[i] += b[i];
      }
      softmax(p_y_given_x);
    
      // gradient step: for softmax with cross-entropy loss, the gradient
      // with respect to the scores is simply y - p(y|x)
      for(int i=0; i<n_out; i++) {
        dy[i] = y[i] - p_y_given_x[i];
    
        for(int j=0; j<n_in; j++) {
          W[i][j] += lr * dy[i] * x[j] / N;
        }
    
        b[i] += lr * dy[i] / N;
      }
      delete[] p_y_given_x;
      delete[] dy;
    }
    
    void LogisticRegression::softmax(double *x) {
      double max = x[0];
      double sum = 0.0;
    
      // subtract the max score before exponentiating to avoid overflow
      for(int i=0; i<n_out; i++) if(max < x[i]) max = x[i];
      for(int i=0; i<n_out; i++) {
        x[i] = exp(x[i] - max);
        sum += x[i];
      } 
    
      for(int i=0; i<n_out; i++) x[i] /= sum;
    }
    
    void LogisticRegression::predict(int *x, double *y) {
      for(int i=0; i<n_out; i++) {
        y[i] = 0;
        for(int j=0; j<n_in; j++) {
          y[i] += W[i][j] * x[j];
        }
        y[i] += b[i];
      }
    
      softmax(y);
    }
    
    
    void test_lr() {
      double learning_rate = 0.1;
      int n_epochs = 500;
    
      int train_N = 6;
      int test_N = 2;
      int n_in = 6;
      int n_out = 2;
    
    
      // training data
      int train_X[6][6] = {
        {1, 1, 1, 0, 0, 0},
        {1, 0, 1, 0, 0, 0},
        {1, 1, 1, 0, 0, 0},
        {0, 0, 1, 1, 1, 0},
        {0, 0, 1, 1, 0, 0},
        {0, 0, 1, 1, 1, 0}
      };
    
      int train_Y[6][2] = {
        {1, 0},
        {1, 0},
        {1, 0},
        {0, 1},
        {0, 1},
        {0, 1}
      };
    
    
      // construct LogisticRegression
      LogisticRegression classifier(train_N, n_in, n_out);
    
    
      // train online
      for(int epoch=0; epoch<n_epochs; epoch++) {
        for(int i=0; i<train_N; i++) {
          classifier.train(train_X[i], train_Y[i], learning_rate);
        }
        // learning_rate *= 0.95;
      }
    
    
      // test data
      int test_X[2][6] = {
        {1, 0, 1, 0, 0, 0},
        {0, 0, 1, 1, 1, 0}
      };
    
      double test_Y[2][2];
    
    
      // test
      for(int i=0; i<test_N; i++) {
        classifier.predict(test_X[i], test_Y[i]);
        for(int j=0; j<n_out; j++) {
          cout << test_Y[i][j] << " ";
        }
        cout << endl;
      }
    
    }
    
    
    int main() {
      test_lr();
      return 0;
    }
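
    One detail of the softmax above worth calling out: it subtracts the maximum score before exponentiating. Without that shift, exp() overflows to infinity for large scores; with it, the result is unchanged, because the common factor e^(−max) cancels in the normalization. A standalone sketch of the same trick, independent of the class above:

    ```cpp
    #include <cmath>
    #include <cstdio>

    // Softmax with the max-subtraction trick, as in the class above.
    void softmax(double *x, int n) {
      double max = x[0];
      for (int i = 1; i < n; i++) if (x[i] > max) max = x[i];
      double sum = 0.0;
      for (int i = 0; i < n; i++) { x[i] = std::exp(x[i] - max); sum += x[i]; }
      for (int i = 0; i < n; i++) x[i] /= sum;
    }

    int main() {
      // scores this large would overflow exp() without the max shift
      double scores[3] = {1000.0, 1001.0, 1002.0};
      softmax(scores, 3);
      double sum = 0.0;
      for (int i = 0; i < 3; i++) {
        printf("p[%d] = %f\n", i, scores[i]);
        sum += scores[i];
      }
      printf("sum = %f\n", sum);  // probabilities sum to 1
      return 0;
    }
    ```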


  • Original post: https://www.cnblogs.com/walccott/p/4957099.html