zoukankan      html  css  js  c++  java
  • 神经网络入门 第3章 S函数

     

        前言

        神经网络是一种很特别的解决问题的方法。本书将用最简单易懂的方式与读者一起从最简单开始,一步一步深入了解神经网络的基础算法。本书将尽量避开让人望而生畏的名词和数学概念,通过构造可以运行的Java程序来实践相关算法。

        关注微信号“逻辑编程"来获取本书的更多信息。

        上一章我们构造的神经元输入和输出没有任何限制,当我们组成神经网络时就会造成很多问题。输出值过大不但容易造成溢出,而且很很多时候是无意义的。比如,判断一个图片是否写着数字“2”,我们需要的是它是“2”的概率,[0-1]范围的实数是最合适的。

        因此,我们需要对我们的神经元输出作出限制。把上一章的函数[y=w*x+b]的输出值通过一个函数变换到[0-1]范围,就可以达到我们的目的。其它逻辑与上一章一致。

        那么有什么函数可以把任意实数值映射到[0-1]范围呢?这样的函数不止一种。在这里我们加入一个比较常用的,叫做S函数(Sigmoid)。

        

        

        当t等于无穷小时函数值趋近0;t值趋近无穷大时函数值趋近于1;0处函数值是0.5。其函数图形如下:

        将上一章的简单线性函数替换S函数的自变量我们得到这个函数:

        f(x) = 1/(1 + exp(-w*x - b))

    其中w决定了S函数在中间处的斜率包括变化方向,b决定了S函数在x轴方向的平移量。下图中展示了几个不同参数的S曲线的形状:

        

        从图中我们可以看到S曲线的一些特征:斜率非常大时类似于二值函数;可以由参数决定从0到1或者从1到0的不同方向;可以平移。S函数是神经网络的一个关键点。我们后面会讨论它的重要性。本章中我们将关注应用S函数之后的单个神经元。

        

        很明显这个函数无法完成我们上一章说的直线拟合,不过它依然可以通过训练拟合到它自己的参数。这就是本章我们要解决的问题。现在我们来继承上一章的Java类,只要覆盖几个函数即可。

        首先覆盖f函数,抽取上一章的f函数成z函数,然后加个S函数:


    public double f(double x) {
    return sigmoid(z(x));
    }

    double z(double x) {
    return x * weight + bias;
    }

    double sigmoid(double z) {
    return 1.0 / (1.0 + Math.exp(-z));
    }

        经过上边的变换,f的输出被影射到[0-1]范围内了。f函数是通过两个函数链接完成的。

        函数变了,其梯度函数也不一样了。比如y(z)=2*z表示一条斜率为2的斜线;z(x)=3*x表示一条斜率为3的斜线。那么y(x)=2*(3*x)=6*x,其斜率是2*3。可见函数的斜率具有链式规则。所以我们只要把S函数的导数乘以z函数的导数就可以了。这里的z函数是我们上一章的f函数。所以我们的梯度函数现在是这样的:

        

    public double[] gradient(double x, double y) {
    double[] g = super.gradient(x, y);
    double z = z(x);
    double dz = dz(z);
    g[0] *= dz;
    g[1] *= dz;
    return g;
    }

    protected double dz(double z) {
    return sigmoid(z) * (1 - sigmoid(z));
    }

        最后,我们覆盖获得测试目标的函数并对入口函数稍作修改,使用新的神经元类:


    protected SingleNeuron getTarget() {
    return new SingleSigmoidNeuron(3, 3);
    }


    public static void main(String... args) {
    SingleSigmoidNeuron n = new SingleSigmoidNeuron(0, 0);
       double rate = 5;
    int epoch = 100;
    int trainingSize = 20;
    for (int i = 0; i < epoch; i++) {
    double[][] data = n.generateTrainingData(trainingSize);
    n.train(data, rate);
    System.out.printf("Epoch: %3d,  W: %f, B: %f ", i, n.weight, n.bias);
    }
    }

    我们的目标(w,b)与上一章一样,产生训练数据的函数也不做修改。现在我们运行一下新的程序:

    Epoch:  90,  W: 2.915254, B: 2.966388 

    Epoch:  91,  W: 2.918170, B: 2.968286 

    Epoch:  92,  W: 2.919169, B: 2.968809 

    Epoch:  93,  W: 2.919169, B: 2.968809 

    Epoch:  94,  W: 2.919169, B: 2.968809 

    Epoch:  95,  W: 2.919169, B: 2.968809 

    Epoch:  96,  W: 2.920288, B: 2.969410 

    Epoch:  97,  W: 2.939107, B: 2.992186 

    Epoch:  98,  W: 2.939107, B: 2.992186 

    Epoch:  99,  W: 2.939229, B: 2.992233 

        呀吼!还是能训练的到我们的目标(w,b)。实验成功!

        读者可以思考一下S函数有什么替代品。答案肯定是有的。您可以自己把S函数替换成另外一个符合条件的函数,看看训练结果和效率怎么样。

    最后完整代码如下:

    package com.luoxq.ann.single; 
    /**
    * Created by luoxq on 17/4/9.
    */
    public class SingleSigmoidNeuron extends SingleNeuron {

    public SingleSigmoidNeuron(double weight, double bias) {
    super(weight, bias);
    }

    public double f(double x) {
    return sigmoid(z(x));
    }

    double z(double x) {
    return x * weight + bias;
    }

    double sigmoid(double z) {
    return 1.0 / (1.0 + Math.exp(-z));
    }

    // c = w*x + b - y
       public double[] gradient(double x, double y) {
    double[] g = super.gradient(x, y);
    double z = z(x);
    double dz = dz(z);
    g[0] *= dz;
    g[1] *= dz;
    return g;
    }

    protected double dz(double z) {
    return sigmoid(z) * (1 - sigmoid(z));
    }

    protected SingleNeuron getTarget() {
    return new SingleSigmoidNeuron(3, 3);
    }


    public static void main(String... args) {
    SingleSigmoidNeuron n = new SingleSigmoidNeuron(0, 0);
    //target: y = 3*x + 3;
           double rate = 5;
    int epoch = 100;
    int trainingSize = 20;
    for (int i = 0; i < epoch; i++) {
    double[][] data = n.generateTrainingData(trainingSize);
    n.train(data, rate);
    System.out.printf("Epoch: %3d,  W: %f, B: %f ", i, n.weight, n.bias);
    }
    }
    }

    这一章我们使用了S函数和链式求导法则。要了解其它章节,请关注微信订阅号逻辑编程。

  • 相关阅读:
    jQuery.getJSON的缓存问题的解决办法
    MFC Tab Control控件的详细使用
    JavaScript 闭包深入理解(closure)
    STL中sort函数用法简介
    STL中qsort的七种用法
    学习Javascript闭包(Closure)
    使用 Visual Studio 分析器找出应用程序瓶颈
    各种语言性能测试工具一览表
    Javascript 链式作用域
    MessageBox、::MessageBox 、AfxMessageBox三者的区别 .
  • 原文地址:https://www.cnblogs.com/javadaddy/p/6732900.html
Copyright © 2011-2022 走看看