zoukankan      html  css  js  c++  java
  • 神经网络入门 第3章 S函数

     

        前言

        神经网络是一种很特别的解决问题的方法。本书将用最简单易懂的方式与读者一起从最简单开始,一步一步深入了解神经网络的基础算法。本书将尽量避开让人望而生畏的名词和数学概念,通过构造可以运行的Java程序来实践相关算法。

        关注微信号“逻辑编程"来获取本书的更多信息。

        上一章我们构造的神经元输入和输出没有任何限制,当我们组成神经网络时就会造成很多问题。输出值过大不但容易造成溢出,而且很很多时候是无意义的。比如,判断一个图片是否写着数字“2”,我们需要的是它是“2”的概率,[0-1]范围的实数是最合适的。

        因此,我们需要对我们的神经元输出作出限制。把上一章的函数[y=w*x+b]的输出值通过一个函数变换到[0-1]范围,就可以达到我们的目的。其它逻辑与上一章一致。

        那么有什么函数可以把任意实数值映射到[0-1]范围呢?这样的函数不止一种。在这里我们加入一个比较常用的,叫做S函数(Sigmoid)。

        

        

        当t等于无穷小时函数值趋近0;t值趋近无穷大时函数值趋近于1;0处函数值是0.5。其函数图形如下:

        将上一章的简单线性函数替换S函数的自变量我们得到这个函数:

        f(x) = 1/(1 + exp(-w*x - b))

    其中w决定了S函数在中间处的斜率包括变化方向,b决定了S函数在x轴方向的平移量。下图中展示了几个不同参数的S曲线的形状:

        

        从图中我们可以看到S曲线的一些特征:斜率非常大时类似于二值函数;可以由参数决定从0到1或者从1到0的不同方向;可以平移。S函数是神经网络的一个关键点。我们后面会讨论它的重要性。本章中我们将关注应用S函数之后的单个神经元。

        

        很明显这个函数无法完成我们上一章说的直线拟合,不过它依然可以通过训练拟合到它自己的参数。这就是本章我们要解决的问题。现在我们来继承上一章的Java类,只要覆盖几个函数即可。

        首先覆盖f函数,抽取上一章的f函数成z函数,然后加个S函数:


    public double f(double x) {
    return sigmoid(z(x));
    }

    double z(double x) {
    return x * weight + bias;
    }

    double sigmoid(double z) {
    return 1.0 / (1.0 + Math.exp(-z));
    }

        经过上边的变换,f的输出被影射到[0-1]范围内了。f函数是通过两个函数链接完成的。

        函数变了,其梯度函数也不一样了。比如y(z)=2*z表示一条斜率为2的斜线;z(x)=3*x表示一条斜率为3的斜线。那么y(x)=2*(3*x)=6*x,其斜率是2*3。可见函数的斜率具有链式规则。所以我们只要把S函数的导数乘以z函数的导数就可以了。这里的z函数是我们上一章的f函数。所以我们的梯度函数现在是这样的:

        

    public double[] gradient(double x, double y) {
    double[] g = super.gradient(x, y);
    double z = z(x);
    double dz = dz(z);
    g[0] *= dz;
    g[1] *= dz;
    return g;
    }

    protected double dz(double z) {
    return sigmoid(z) * (1 - sigmoid(z));
    }

        最后,我们覆盖获得测试目标的函数并对入口函数稍作修改,使用新的神经元类:


    protected SingleNeuron getTarget() {
    return new SingleSigmoidNeuron(3, 3);
    }


    public static void main(String... args) {
    SingleSigmoidNeuron n = new SingleSigmoidNeuron(0, 0);
       double rate = 5;
    int epoch = 100;
    int trainingSize = 20;
    for (int i = 0; i < epoch; i++) {
    double[][] data = n.generateTrainingData(trainingSize);
    n.train(data, rate);
    System.out.printf("Epoch: %3d,  W: %f, B: %f ", i, n.weight, n.bias);
    }
    }

    我们的目标(w,b)与上一章一样,产生训练数据的函数也不做修改。现在我们运行一下新的程序:

    Epoch:  90,  W: 2.915254, B: 2.966388 

    Epoch:  91,  W: 2.918170, B: 2.968286 

    Epoch:  92,  W: 2.919169, B: 2.968809 

    Epoch:  93,  W: 2.919169, B: 2.968809 

    Epoch:  94,  W: 2.919169, B: 2.968809 

    Epoch:  95,  W: 2.919169, B: 2.968809 

    Epoch:  96,  W: 2.920288, B: 2.969410 

    Epoch:  97,  W: 2.939107, B: 2.992186 

    Epoch:  98,  W: 2.939107, B: 2.992186 

    Epoch:  99,  W: 2.939229, B: 2.992233 

        呀吼!还是能训练的到我们的目标(w,b)。实验成功!

        读者可以思考一下S函数有什么替代品。答案肯定是有的。您可以自己把S函数替换成另外一个符合条件的函数,看看训练结果和效率怎么样。

    最后完整代码如下:

    package com.luoxq.ann.single; 
    /**
    * Created by luoxq on 17/4/9.
    */
    public class SingleSigmoidNeuron extends SingleNeuron {

    public SingleSigmoidNeuron(double weight, double bias) {
    super(weight, bias);
    }

    public double f(double x) {
    return sigmoid(z(x));
    }

    double z(double x) {
    return x * weight + bias;
    }

    double sigmoid(double z) {
    return 1.0 / (1.0 + Math.exp(-z));
    }

    // c = w*x + b - y
       public double[] gradient(double x, double y) {
    double[] g = super.gradient(x, y);
    double z = z(x);
    double dz = dz(z);
    g[0] *= dz;
    g[1] *= dz;
    return g;
    }

    protected double dz(double z) {
    return sigmoid(z) * (1 - sigmoid(z));
    }

    protected SingleNeuron getTarget() {
    return new SingleSigmoidNeuron(3, 3);
    }


    public static void main(String... args) {
    SingleSigmoidNeuron n = new SingleSigmoidNeuron(0, 0);
    //target: y = 3*x + 3;
           double rate = 5;
    int epoch = 100;
    int trainingSize = 20;
    for (int i = 0; i < epoch; i++) {
    double[][] data = n.generateTrainingData(trainingSize);
    n.train(data, rate);
    System.out.printf("Epoch: %3d,  W: %f, B: %f ", i, n.weight, n.bias);
    }
    }
    }

    这一章我们使用了S函数和链式求导法则。要了解其它章节,请关注微信订阅号逻辑编程。

  • 相关阅读:
    pgspider sqlite mysql docker 镜像
    pgspider docker 镜像
    pgspider基于pg 的高性能数据可视化sql 集群引擎
    diesel rust orm 框架试用
    golang 条件编译
    Performance Profiling Zeebe
    bazel 学习一 简单java 项目运行
    一个好用node http keeplive agnet
    gox 简单灵活的golang 跨平台编译工具
    mailhog 作为smtp server mock工具
  • 原文地址:https://www.cnblogs.com/javadaddy/p/6732900.html
Copyright © 2011-2022 走看看