zoukankan      html  css  js  c++  java
  • 机器学习学习笔记 梯度下降

    梯度下降算法的思维过程:

    x为训练数据输入值。y为训练数据输出值。θ 为 x的系数,也就是要求的。

    1.预测公式 h(x) = ∑θix。“使 θ尽可能的准确”,可以理解为理想情况下对每一组样本都有 ( h(x(i)) - y(i) )= 0 ,非理想情况下希望 J(θ) = ∑( ( h(x(i)) - y(i) )2 /2 )尽可能小。

    2.梯度下降的思路是: 先取一组随机的 θ 值,代入样本数据,通过求导等计算算出能进一步减小J(θ) 的 θ 值。重新 对θ 赋值,再不断进行这个步骤,直到 J(θ) 达到局部最小。 

    3.这个对 θ 重新赋值的算法是 “ θi := θi - ( aJ(θ) / aθi )” ,后面这一项的意思是 J(θ) 对 θi 求偏导 。假设只有一个训练样本时,计算的结果为 : θi := θ- ( h(x) - y )* xi 。

    4.当有m项训练数据时,这个公式就变成了 θi := θ- a∑j:1->m( h(x(j)) - y(j) ) * xi(j)  。这称成为批量梯度下降公式。a表示下降的步长,通常是手工设置。

    5.因为当m很大时,每一次下降都要对很大的样本数进行求和,非常耗时,所以有了随机梯度下降公式 : θi := θ- ( h(x(j)) - y(j) ) * xi(j)。即每一次下降只使用一个样本。

    为了简化这个公式以及更加明确对这个它的理解,一下引入了矩阵来重新对它定义。先引入几个公式和表示法。

    符号:

    θ

    θJ(θ) = [ aJ(θ)/aθ1 , aJ(θ)/aθ2 ... aJ(θ)/aθi ]      意思是,当 θ 为单行矩阵时,对 J(θ) 求导的结果也是一个单行矩阵,矩阵的每一项就是这个式子对当前的 θ 求偏导。

    推广到多行矩阵m*n,这个公式表示为:

    θF(x) =

    [

    aF(x)/ax11 ... aF(x)/ax1n

                    ...

    aF(x)/axm1 ... aF(x)/axmn

    ]

    tr(A)

    tr(A) = ∑aii      意思是,对某一矩阵求对角线和。A必须为方矩阵,这称为矩阵的迹。

    公式

    tr(A) = tr(AT) ; tr(AB) = tr(BA) ; tr(ABC) = tr(CAB) = tr( BCA ); 

    当A是实数时,tr(A) = A;

    tr(AB) = BT; ▽tr(ABATC) = CAB + CTABT;

    有了这些公式,下面就能开始使用向量来表示J(θ)了。以下是具体的思路:

    1.用 X 来表示训练数据的输入值,这也被称为design matrix;

    X = 

    [

    x1(1) ... xn(m)

            ... 

    xm(1) ... xn(m)

    ]

    2.用 θ (列向量) 右乘 X。就得到了 

    Xθ = 

    [

    x(1)θ = h(x(1))

     ... 

    x(m)θ = h(x(m))

    ]

    用 Y 表示 输出值得列向量。则

    Xθ - Y = 

    [

    h(x(1)) - y(1)

    ...

    h(x(m)) - y(m)

    ]

    3.最厉害的一步来了,利用上面的式子,J(θ) 可以表示成

    J(θ) = ( ∑j:1->m( h(x(j)) -y(j) )2 )/2 = (Xθ - Y)T(Xθ - Y) / 2;

    4.换个思路来看梯度下降,其实就是求一个θ值,使得该点上任何方向的偏导都为0。也就是▽θJ(θ) = 0。这里的0是一个向量。

    5.将式子展开,再利用前面的公式,最后得到神奇的:

    XTXθ = XTY     即     θ = (XTX)-1XTY

    总结:

    1.有很多公式需要自己证明。

    2.引入向量和矩阵是为了简化运算。完全去掉梯度下降的迭代过程。

    3.最后的结果应该是对h(x)有一定限制的,例如只适合线性的?我需要去验证一下。

  • 相关阅读:
    conn
    快速指数算法+Python代码
    扩展欧几里得算法+Python代码
    最速下降法+Matlab代码
    第二类生日攻击算法
    遗传算法+Python代码
    模糊聚类+Matlab代码
    数据库检索
    Spring Data Jpa依赖和配置
    上传Typora到博客园(解决图片缩放问题)
  • 原文地址:https://www.cnblogs.com/sskyy/p/2789274.html
Copyright © 2011-2022 走看看