zoukankan      html  css  js  c++  java
  • 训练网络

      解决训练任务,包括两部分内容

    第一部分:针对给定的训练样本计算输出。这与query()函数所做的工作没什么区别。

    第二部分:将计算所得到的输出与期望的目标值做对比,使用差值来指导网络权重的更新。

    其中,第一部分的代码如下所示:

     1     def train(self,input_list,target_list):
     2         # 转换输入输出列表到二维数组
     3         inputs = numpy.array(input_list, ndmin=2).T
     4         targets = numpy.array(target_list,ndmin= 2).T
     5         # 计算到隐藏层的信号
     6         hidden_inputs = numpy.dot(self.wih, inputs)
     7         # 计算隐藏层输出的信号
     8         hidden_outputs = self.activation_function(hidden_inputs)
     9         # 计算到输出层的信号
    10         final_inputs = numpy.dot(self.who, hidden_outputs)
    11         final_outputs = self.activation_function(final_inputs)

    这部分与query()中的区别在于多了一个期望值,因为我们需要期望值来训练网络,所以这部分必不可少。

    第二部分:

    1.首先需要计算误差,也就是期望值减去输出的实际值,以此可表示为:

    output_errors = targets - final_outputs

    2.那么,如何根据得到的输出误差来更新隐藏层和输出层,输入层和隐藏层之间的权重呢?首先,输出的误差来源于隐藏层传播的误差,隐藏层各个节点的误差具体分配多少呢?

                        errorshidden = weightsThidden_output * errorsoutput

    使用python上式可表示为:

    hidden_errors = numpy.dot(self.who.T,output_errors)

    其次,利用公式更新权重:

          ΔWj,k = α * Ek * sigmod(Ok) * (1 - sigmod(OK)) * OjT

    使用python可表示为:

    1 #隐藏层和输出层权重更新
    2         self.who += self.lr * numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),
    3                                         numpy.transpose(hidden_outputs))
    4         #输入层和隐藏层权重更新
    5         self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),
    6                                         numpy.transpose(inputs))

    因此,完整的神经网络代码表示为:

     1 import numpy
     2 import scipy.special
     3 
     4 # 神经网络类定义
     5 class NeuralNetwork():
     6     # 初始化神经网络
     7     def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
     8         # 设置输入层节点,隐藏层节点和输出层节点的数量
     9         self.inodes = inputnodes
    10         self.hnodes = hiddennodes
    11         self.onodes = outputnodes
    12         # 学习率设置
    13         self.lr = learningrate
    14         # 权重矩阵设置 正态分布
    15         self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
    16         self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))
    17         # 激活函数设置,sigmod()函数
    18         self.activation_function = lambda x: scipy.special.expit(x)
    19         pass
    20 
    21     # 训练神经网络
    22     def train(self,input_list,target_list):
    23         # 转换输入输出列表到二维数组
    24         inputs = numpy.array(input_list, ndmin=2).T
    25         targets = numpy.array(target_list,ndmin= 2).T
    26         # 计算到隐藏层的信号
    27         hidden_inputs = numpy.dot(self.wih, inputs)
    28         # 计算隐藏层输出的信号
    29         hidden_outputs = self.activation_function(hidden_inputs)
    30         # 计算到输出层的信号
    31         final_inputs = numpy.dot(self.who, hidden_outputs)
    32         final_outputs = self.activation_function(final_inputs)
    33 
    34         output_errors = targets - final_outputs
    35         hidden_errors = numpy.dot(self.who.T,output_errors)
    36 
    37         #隐藏层和输出层权重更新
    38         self.who += self.lr * numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),
    39                                         numpy.transpose(hidden_outputs))
    40         #输入层和隐藏层权重更新
    41         self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),
    42                                         numpy.transpose(inputs))
    43         pass
    44     # 查询神经网络
    45     def query(self, input_list):
    46         # 转换输入列表到二维数组
    47         inputs = numpy.array(input_list, ndmin=2).T
    48         # 计算到隐藏层的信号
    49         hidden_inputs = numpy.dot(self.wih, inputs)
    50         # 计算隐藏层输出的信号
    51         hidden_outputs = self.activation_function(hidden_inputs)
    52         # 计算到输出层的信号
    53         final_inputs = numpy.dot(self.who, hidden_outputs)
    54         final_outputs = self.activation_function(final_inputs)
    55 
    56         return final_outputs
  • 相关阅读:
    基于redis实现滑动窗口式的短信发送接口限流
    Linux 宝塔下的PHP如何与本地的nginx关联
    Linux 下php安装gd库
    Linux Mysql8重置密码
    PHP 无限分级类
    redis 缓存穿透,缓存雪崩,缓存击穿
    yii2 事务添加
    ConcurrentHashMap
    Volatile
    this引用的逸出
  • 原文地址:https://www.cnblogs.com/carlber/p/9693600.html
Copyright © 2011-2022 走看看