  • [Repost] Derivation of the Normal Equation for linear regression

    I was going through the Coursera "Machine Learning" course, and in the section on multivariate linear regression something caught my eye. Andrew Ng presented the Normal Equation as an analytical solution to the linear regression problem with a least-squares cost function. He mentioned that in some cases (such as for small feature sets) using it is more effective than applying gradient descent; unfortunately, he left its derivation out.

    Here I want to show how the normal equation is derived.

    First, some terminology. The symbols below are consistent with the machine learning course, not with the exposition of the normal equation on Wikipedia and other sites; semantically it's all the same, just the symbols are different.

    Given the hypothesis function:

    $$h_\theta(x) = \theta_0 x_0 + \theta_1 x_1 + \cdots + \theta_n x_n$$

    We'd like to minimize the least-squares cost:

    $$J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2$$

    Where $x^{(i)}$ is the i-th sample (from a set of m samples) and $y^{(i)}$ is the i-th expected result.

    To proceed, we'll represent the problem in matrix notation; this is natural, since we essentially have a system of linear equations here. The regression coefficients $\theta$ we're looking for are the vector:

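    $$\theta = \begin{pmatrix} \theta_0 \\ \theta_1 \\ \vdots \\ \theta_n \end{pmatrix} \in \mathbb{R}^{n+1}$$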

    Each of the m input samples is similarly a column vector with n+1 rows, $x_0$ being 1 for convenience. So we can now rewrite the hypothesis function as:

    $$h_\theta(x) = \theta^T x$$
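    The hypothesis and the least-squares cost defined so far can be sketched in a few lines of NumPy. This is only an illustration: the function names `hypothesis` and `cost_loop` and the three data points are made up for this example and are not part of the course material.

```python
import numpy as np

def hypothesis(theta, x):
    """h_theta(x) = theta^T x, where x[0] is assumed to be 1."""
    return float(np.dot(theta, x))

def cost_loop(theta, samples, targets):
    """Least-squares cost J = 1/(2m) * sum_i (h(x_i) - y_i)^2, via an explicit loop."""
    m = len(samples)
    total = 0.0
    for x_i, y_i in zip(samples, targets):
        total += (hypothesis(theta, x_i) - y_i) ** 2
    return total / (2 * m)

# Made-up data: m = 3 samples, n = 1 feature, with x_0 = 1 prepended to each sample.
samples = [np.array([1.0, 0.0]), np.array([1.0, 1.0]), np.array([1.0, 2.0])]
targets = [1.0, 3.0, 5.0]

theta_exact = np.array([1.0, 2.0])  # y = 1 + 2x fits the data exactly
theta_off = np.array([0.0, 2.0])    # every prediction is off by exactly 1

print(cost_loop(theta_exact, samples, targets))  # 0.0
print(cost_loop(theta_off, samples, targets))    # 0.5
```

    With `theta_off` each of the three residuals is -1, so the sum of squares is 3 and the cost is 3 / (2 * 3) = 0.5.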

    When this is summed over all samples, we can dip further into matrix notation. We'll define the "design matrix" X (uppercase X) as a matrix of m rows and n+1 columns, in which the i-th row is the i-th sample written as a row (the vector $(x^{(i)})^T$). The i-th component of the product $X\theta$ is then $\theta^T x^{(i)}$, the hypothesis evaluated on the i-th sample. With this, we can rewrite the least-squares cost as follows, replacing the explicit sum by matrix multiplication:

    $$J(\theta) = \frac{1}{2m} (X\theta - y)^T (X\theta - y)$$
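    As a quick sanity check, the explicit sum and the matrix form of the cost can be compared numerically. The sketch below uses randomly generated data (the sizes, the seed and the variable names are arbitrary choices for illustration) and builds the design matrix with a leading column of ones.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 5, 2

# Design matrix: m rows, n + 1 columns, first column is x_0 = 1.
X = np.hstack([np.ones((m, 1)), rng.normal(size=(m, n))])
y = rng.normal(size=m)
theta = rng.normal(size=n + 1)

# Explicit sum over the samples: row X[i] is the i-th sample.
loop_cost = sum((X[i] @ theta - y[i]) ** 2 for i in range(m)) / (2 * m)

# Matrix form: (X theta - y)^T (X theta - y) / (2m).
residual = X @ theta - y
matrix_cost = (residual @ residual) / (2 * m)

agree = bool(np.isclose(loop_cost, matrix_cost))
print(agree)
```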

    Now, using some matrix transpose identities, we can simplify this a bit. I'll drop the $\frac{1}{2m}$ factor, since a positive constant factor doesn't change where the derivative is zero:

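    $$J(\theta) = \left( (X\theta)^T - y^T \right) (X\theta - y)$$

    $$J(\theta) = (X\theta)^T X\theta - (X\theta)^T y - y^T (X\theta) + y^T y$$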

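    Note that $X\theta$ is a vector, and so is y. So when we multiply one by the other, it doesn't matter what the order is (as long as the dimensions work out): $(X\theta)^T y$ and $y^T (X\theta)$ are the same scalar, since a scalar equals its own transpose. So we can further simplify: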

    $$J(\theta) = \theta^T X^T X \theta - 2 (X\theta)^T y + y^T y$$
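    The two claims used in this simplification, that the cross terms are the same scalar and that the expanded form equals the compact one, are easy to check numerically. A minimal sketch on made-up random data (the seed and sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(1)
m, n = 6, 3
X = np.hstack([np.ones((m, 1)), rng.normal(size=(m, n))])
y = rng.normal(size=m)
theta = rng.normal(size=n + 1)

Xtheta = X @ theta

# The two cross terms are the same scalar, whatever the order.
cross_a = Xtheta @ y   # (X theta)^T y
cross_b = y @ Xtheta   # y^T (X theta)

# Compact form (with the 1/(2m) factor dropped) versus the expanded form.
compact = (Xtheta - y) @ (Xtheta - y)
expanded = theta @ X.T @ X @ theta - 2 * (Xtheta @ y) + y @ y

print(np.isclose(cross_a, cross_b), np.isclose(compact, expanded))
```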

    Recall that here $\theta$ is our unknown. To find where the above function has a minimum, we will differentiate with respect to $\theta$ and set the result to 0. Differentiating with respect to a vector may feel uncomfortable, but there's nothing to worry about. Recall that here we only use matrix notation to conveniently represent a system of linear formulae. So we differentiate with respect to each component of the vector, and then combine the resulting partial derivatives into a vector again. The result is:

    $$\frac{\partial J}{\partial \theta} = 2 X^T X \theta - 2 X^T y = 0$$

    Or:

    $$X^T X \theta = X^T y$$
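    One way to gain confidence in the vector derivative is to compare it against a finite-difference approximation taken one component at a time, which mirrors the "differentiate by each component" idea. The sketch below is an illustration on made-up random data; the helper names `cost`, `analytic_gradient` and `numerical_gradient` are hypothetical.

```python
import numpy as np

def cost(theta, X, y):
    """J(theta) = (X theta - y)^T (X theta - y), with the 1/(2m) factor dropped."""
    residual = X @ theta - y
    return residual @ residual

def analytic_gradient(theta, X, y):
    """Gradient from the derivation: 2 X^T X theta - 2 X^T y."""
    return 2 * X.T @ X @ theta - 2 * X.T @ y

def numerical_gradient(theta, X, y, eps=1e-6):
    """Central differences, perturbing one component of theta at a time."""
    grad = np.zeros_like(theta)
    for j in range(theta.size):
        step = np.zeros_like(theta)
        step[j] = eps
        grad[j] = (cost(theta + step, X, y) - cost(theta - step, X, y)) / (2 * eps)
    return grad

rng = np.random.default_rng(2)
m, n = 8, 3
X = np.hstack([np.ones((m, 1)), rng.normal(size=(m, n))])
y = rng.normal(size=m)
theta = rng.normal(size=n + 1)

# Largest disagreement between the two gradients; should be tiny.
max_gap = np.max(np.abs(analytic_gradient(theta, X, y) - numerical_gradient(theta, X, y)))
print(max_gap)
```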

    [Update 27-May-2015: I've written another post that explains in more detail how these derivatives are computed.]
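    For readers who want the component-wise step spelled out here, the following is a sketch. The shorthand $A = X^T X$ (a symmetric matrix) and $b = X^T y$ is introduced only for this illustration and does not appear in the course notation.

```latex
% Cost with the 1/(2m) factor dropped, written in components,
% using the shorthand A = X^T X (symmetric) and b = X^T y:
J(\theta) = \sum_{j=0}^{n} \sum_{k=0}^{n} \theta_j A_{jk} \theta_k
            - 2 \sum_{j=0}^{n} b_j \theta_j + y^T y

% Partial derivative with respect to one component \theta_l:
\frac{\partial J}{\partial \theta_l}
  = \sum_{k=0}^{n} A_{lk} \theta_k + \sum_{j=0}^{n} \theta_j A_{jl} - 2 b_l
  = 2 \sum_{k=0}^{n} A_{lk} \theta_k - 2 b_l
  \qquad \text{(since } A_{jl} = A_{lj} \text{)}

% Stacking the n + 1 partial derivatives into a vector:
\frac{\partial J}{\partial \theta} = 2 A \theta - 2 b = 2 X^T X \theta - 2 X^T y
```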

    Now, assuming that the matrix $X^T X$ is invertible, we can multiply both sides by $(X^T X)^{-1}$ and get:

    $$\theta = (X^T X)^{-1} X^T y$$

    Which is the normal equation.
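    In code, the normal equation is only a couple of lines. The sketch below is an illustration on made-up data generated from y = 1 + 2x; the function name `normal_equation` is hypothetical. It solves the linear system with `np.linalg.solve` instead of forming the inverse explicitly, which is a common numerical preference rather than something the derivation requires, and it compares the result with NumPy's own least-squares routine.

```python
import numpy as np

def normal_equation(X, y):
    """Solve (X^T X) theta = X^T y for theta, assuming X^T X is invertible."""
    return np.linalg.solve(X.T @ X, X.T @ y)

# Made-up data lying exactly on the line y = 1 + 2x.
x = np.array([0.0, 1.0, 2.0, 3.0])
X = np.column_stack([np.ones_like(x), x])  # design matrix with x_0 = 1
y = 1.0 + 2.0 * x

theta = normal_equation(X, y)
print(theta)  # approximately [1. 2.]

# Cross-check against NumPy's least-squares solver.
theta_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(np.allclose(theta, theta_lstsq))
```

    If $X^T X$ is singular (for example, when two feature columns are identical), `np.linalg.solve` raises `LinAlgError`, which matches the invertibility assumption made above.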

  • Original article: https://www.cnblogs.com/immortal-worm/p/5719221.html