zoukankan      html  css  js  c++  java
  • 最小二乘线性回归与梯度下降算法

    线性回归

    预测函数:

    \[h_\theta(x)=\theta_0 + \theta_1x_1 + \theta_2x_2 \]

    \(x_0 = 1\)

    \[ h_\theta(x) = \sum_{i=0}^n\theta_ix_i \]

    其中 \(\theta\)这参数,n表示特征数量


    成本函数:

    \[J(\theta) = \frac{1}{2}\sum_{i=1}^m(h_\theta(x^{(i)}) - y^{(i)})^2 \]

    其中m表示样本数量,\((x^{(i)},y^{(i)})\)表示第i个样本对儿,1/2是为了方便计算


    我们目标:

    \[\min_\theta J(\theta) \]

    最小二乘成本函数的概率学解释(并不是唯一的解释)

    假设该线性模型符合高斯分布

    \[p(y^{(i)}|x^{(i)};\theta) = \frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\sqrt{(y^{(i)}-\theta^Tx^{(i)^2})}{2\sigma^2}\right) \]

    似然函数为:

    \[L(\theta) = L(\theta;X,\vec{y}) = p(\vec{y}|X;θ) \]

    因此,线性回归的似然函数:

    \[\begin{align} L(\theta) & = \prod_{i=1}^mp(y^{(i)}|x^{(i)};\theta) \\ & = \prod_{i=1}^m\frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\frac{((y^{(i)}) - \theta^Tx^{(i)})^2}{2\sigma^2}\right) \end{align} \]

    对数似然度为

    \[\begin{align} l(\theta) & = \log L(\theta)\\ & = \log\prod_{i=1}^m\frac{1}{\sqrt{2\pi}\sigma}\exp\left(-\frac{((y^{(i)}) - \theta^Tx^{(i)})^2}{2\sigma^2}\right)\\ & = \sum_{i=1}^m\log\frac{1}{{2\sigma^2}}\exp\left(-\frac{((y^{(i)}) - \theta^Tx^{(i)})^2}{2\sigma^2}\right)\\ & = m\log\frac{1}{2\pi\sigma}-\frac{1}{\sigma^2}\cdot\frac{1}{2}\sum_{i=1}^m(y^{(i)} - \theta^Tx^{(i)})^2 \end{align} \]

    最大化\(l(\theta)\)相当于最小化,也就是成本函数

    \[\frac{1}{2}\sum_{i=1}^m(y^{(i)} - \theta^Tx^{(i)})^2 \]

    梯度下降

    步骤

    1. 初始化\(\theta\)
    2. 改变\(\theta\)使\(J(\theta)\)降低
    3. 终止

    每一步使:

    \[\theta_i:=\theta_i-\alpha\frac{\partial}{\partial\theta_i}J(\theta) \]

    \(\theta_i\)表示第个参数,:=表示赋值符号,\(\alpha\)表示步长也叫学习速率,\(\frac{\partial}{\partial\theta_i}\)是对\(\theta_i\)求偏导数

    对于单个样本:

    \[\begin{align} \frac{\partial}{\partial\theta_i}J(\theta) & = \frac{\partial}{\partial\theta_i}\frac{1}{2}(h_\theta(x) - y)^2 \\ & = 2\cdot\frac{1}{2}(h_\theta(x) - y) \cdot \frac{\partial}{\partial\theta_i}(h_\theta(x)-y) \\ & = (h_\theta(x) - y)\cdot\frac{\partial}{\partial\theta_i}(\theta_0x_0+\theta_1x_1+... + \theta_nx_n-y) \\ & = (h_\theta(x) - y)\cdot x_i \end{align} \]

    (1)-(2) 根据链式求导法则,(3)-(4)\((\theta_0x_0+\theta_1x_1+... + \theta_nx_n-y)\)只有\(\theta_i\)起作用
    因此

    \[\theta_i:=\theta_i-\alpha(h_\theta(x) - y)\cdot x_i \]


    对于多个样本:

    \[\theta_i:=\theta_i-\sum_{j=1}^m\alpha(h_\theta(x^{(j)}) - y^{(j)})\cdot x_i^{(j)} \]

    批梯度下降算法代码

    public static void main(String[] args) {
        //训练集合输入值
        double x[][] = {
                {1, 4},
                {2, 5},
                {5, 1},
                {4, 2}};
        //训练集合期望输出值
        double y[] = {
                19,
                26,
                19,
                20};
        //参数
        double theta[] = {0, 0};
        //步长
        double learningRate = 0.01;
        //样本集合长度
        int m = x.length;
        //迭代次数
        int iteration = 100;
        for (int index = 0; index < iteration; index++) {
            double err_sum = 0;
            double[] tmp = new double[theta.length];
            System.arraycopy(theta, 0, tmp, 0, theta.length);
            //i->m
            for (int j = 0; j < m; j++) {
                double h = 0;
                for (int i = 0; i < theta.length; i++) {//x[j][i]表示第j个样本的第i个特征
                    h += x[j][i] * theta[i];
                }
                err_sum = h - y[j];
                for (int i = 0; i < theta.length; i++) {
                    tmp[i] -= learningRate * err_sum * x[j][i];
                }
            }
            //批量更新参数
            theta = tmp;
        }
        for (int i = 0; i < theta.length; i++) {
            System.out.println("theta" + i + "=" + theta[i]);
        }
    }
    
  • 相关阅读:
    【配置属性】—Entity Framework 对应表字段的类型的设定配置方法
    EntityFrame Work 6 Code First 配置字段为varchar 类型
    Echarts xAxis boundaryGap
    JavaScript Array和string的转换
    SQL server :主键和外键
    SQL server :“增删改查” 之 “改”
    SQL server :“增删改查” 之 “删”
    SQL server :“增删改查” 之 “增”
    Oracle不能连接故障排除【TNS-12541:TNS:无监听程序】
    LNMP平台部署及应用
  • 原文地址:https://www.cnblogs.com/hwyang/p/6431250.html
Copyright © 2011-2022 走看看