zoukankan      html  css  js  c++  java
  • 梯度下降法

    有一个圆锥曲面:

    [z(x,y) = sqrt {x^2 + y^2} ]

    它的开口向上, 最低点为原点(O). (z)方向的俯视图如下:

    现想象你自己是一个有意识的能移动的小质点, 站在曲面上一个非原点的位置((x, y))上, 如图上的小红圈所示. 你看不到整个曲面的全貌, 能看到的只有以你为中心的,半径为(r)(为投影在xoy平面上的值)的水平视野内的地貌. 因为曲面是光滑的, 所以地貌全部内容就是坡度. 现在你要做的是, 根据眼前能看到的地貌, 一步一步走, 步长为(r), 以最少的步数到达最低点. 注意, 因为你是质子, 本身很小, 所以你的水平视野半径(r)也会很小很小, 几乎为0.

    你唯一能做的就是走一步看一步. 每一步能依据的信息就是各个方向的坡度. 用射线段(l)表示方向, 起点是你自己, 长度为(r), 与(x)轴的夹角为(alpha). 假如你选择了(alpha)方向, 则走完这一步之后, 你的xoy坐标为((x + rcosalpha, y + rsinalpha)), 高度(z)变为(z(x + rcosalpha, y + rsinalpha)). 你唯一知晓的坡度, 即高度变化率, 可以量化为:

    [坡度 = frac {z(x + rcosalpha, y + rsinalpha)}{r} = frac {partial z}{partial x} cosalpha + frac {partial z}{partial y} sinalpha ]

    在微积分里, 它也被称为(z)沿射线段(l)的方向导数, 用(frac {partial z}{partial l})表示. 当然, 你能观察到的只是它的数值, 而非表达式.
    将其写成两个向量的内积形式:

    [frac {partial z}{partial l} = (frac {partial z}{partial x}, frac {partial z}{partial y})(cosalpha, sinalpha)^T = grad^T n ]

    (grad = (frac {partial z}{partial x}, frac {partial z}{partial y})^T)也称为(z)((x, y))处的梯度. (n=(cosalpha, sinalpha)^T)则是(l)的单位方向向量.
    因为(Delta z propto frac {partial z}{partial l} = grad^T n), 所以:

    • (n)(grad)同向时, (grad^T n)为正数最大. 若沿(n)方向走一步, (z)值最大限度的增大.
    • (n)(grad)反向时, (grad^T n)为负数最大. 若沿(n)方向走一步, (z)值最大限度的降低.

    由于你的目的是往下走, 所以应该选择(-grad)方向. 每走一步, (x与y)的变化方式为:

    [(x, y) gets (x, y) - rfrac {grad^T}{||grad||} ]

    嗯, 记住你现在还是个质子, 你的(r)很小很小. 如果你离目的地(原点)还很远的话, 要费很多很多极多的步子才能到达. 切换到实际应用中求最小值点的场景, 就意味着很长很长的计算时间. 所以往往不是将(r)固定为一个极小的值, 而是将(frac r{|grad|})固定为一个值: (lr), 称作为step size. 在机器学习里就是learning rate, 学习速率. 所以上式改为:

    [(x, y) gets (x, y) - lr*(frac {partial z}{partial x}, frac {partial z}{partial y}) ]

    路径如图中红线所示:

    这种数值方法又叫做(单纯的)牛顿梯度下降法, 用于求最小值(点), 可以放心的推广到更高维空间. 不过有一个前提是目标函数是凹的, 即乘以(-1)后是凸的. 不然, 最后有可能会停留在局部最优而非全局最优.

    用于画出路径的matlab代码:

    close all;
    phi = pi/6;
    a = -pi:.05*pi:pi;
    r = 0: .1: 2;
    [A, R] = meshgrid(a, r);
    X = R.* cos(A);
    Y = R.* sin(A);
    Z = cot(phi) * sqrt(X.^2 + Y.^2);
    surf(X, Y, Z);
    hold on;
    plot3([1],[1], cot(phi)*sqrt(2), 'ro');
    alpha(.8);
    
    Xs = [];
    Ys = [];
    Zs = [];
    lr = 0.001;
    x = 1;
    y = 1;
    %z = cot(phi) * sqrt(x^2 + x^2);
    for i = 1:10^4
        x = x  - lr * x / sqrt(x^2 + y^2);
        y = y  - lr * y / sqrt(x^2 + y^2);
        z = cot(phi) * sqrt(x^2 + x^2)
     %   plot3(x,y, z, 'r.');
        Xs = [Xs, x];
        Ys = [Ys, y];
        Zs = [Zs, z];
    end
     plot3(Xs, Ys, Zs, 'r.');
    
    
  • 相关阅读:
    初学flask
    第一次使用pyqt5解决的几个小问题
    一些细节
    关于random
    go语言 方法
    go 语言 struct 另类构造函数 继承
    go 语言 链表 的增删改查
    go 语言 链表
    go 语言struct
    无题
  • 原文地址:https://www.cnblogs.com/dengdan890730/p/5557024.html
Copyright © 2011-2022 走看看