zoukankan      html  css  js  c++  java
  • 神经网络的训练和测试 python

      承接上一节,神经网络需要训练,那么训练集来自哪?测试的数据又来自哪?

      《python神经网络编程》一书给出了训练集,识别图片中的数字。测试集的链接如下:

      https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/master/mnist_dataset/mnist_test_10.csv

    为了方便,这只是一个小的测试集,才10个。

      训练集链接:https://raw.githubusercontent.com/makeyourownneuralnetwork/makeyourownneuralnetwork/master/mnist_dataset/mnist_train_100.csv

    这是包含100个数据的训练集。

      训练集和测试集的每段的第一个字母是期望的数字,每段剩余的文本是表示这个数字的像素集合,为784个数据。为了计算,我们要把文本转化为数字进行存放。把第一个数据当作期望数据,剩余的784个数据当作输入。因此输入节点设为784个。输出节点设为10个,因为要识别的是10个数据0到9。隐藏层节点选为100个,并没有进行科学的计算。

      

     1 import numpy
     2 import scipy.special
     3 import matplotlib.pyplot as plt
     4 import pylab
     5 # 神经网络类定义
     6 class NeuralNetwork():
     7     # 初始化神经网络
     8     def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
     9         # 设置输入层节点,隐藏层节点和输出层节点的数量
    10         self.inodes = inputnodes
    11         self.hnodes = hiddennodes
    12         self.onodes = outputnodes
    13         # 学习率设置
    14         self.lr = learningrate
    15         # 权重矩阵设置 正态分布
    16         self.wih = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes))
    17         self.who = numpy.random.normal(0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes))
    18         # 激活函数设置,sigmod()函数
    19         self.activation_function = lambda x: scipy.special.expit(x)
    20         pass
    21 
    22     # 训练神经网络
    23     def train(self,input_list,target_list):
    24         # 转换输入输出列表到二维数组
    25         inputs = numpy.array(input_list, ndmin=2).T
    26         targets = numpy.array(target_list,ndmin= 2).T
    27         # 计算到隐藏层的信号
    28         hidden_inputs = numpy.dot(self.wih, inputs)
    29         # 计算隐藏层输出的信号
    30         hidden_outputs = self.activation_function(hidden_inputs)
    31         # 计算到输出层的信号
    32         final_inputs = numpy.dot(self.who, hidden_outputs)
    33         final_outputs = self.activation_function(final_inputs)
    34 
    35         output_errors = targets - final_outputs
    36         hidden_errors = numpy.dot(self.who.T,output_errors)
    37 
    38         #隐藏层和输出层权重更新
    39         self.who += self.lr * numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),
    40                                         numpy.transpose(hidden_outputs))
    41         #输入层和隐藏层权重更新
    42         self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),
    43                                         numpy.transpose(inputs))
    44         pass
    45     # 查询神经网络
    46     def query(self, input_list):
    47         # 转换输入列表到二维数组
    48         inputs = numpy.array(input_list, ndmin=2).T
    49         # 计算到隐藏层的信号
    50         hidden_inputs = numpy.dot(self.wih, inputs)
    51         # 计算隐藏层输出的信号
    52         hidden_outputs = self.activation_function(hidden_inputs)
    53         # 计算到输出层的信号
    54         final_inputs = numpy.dot(self.who, hidden_outputs)
    55         final_outputs = self.activation_function(final_inputs)
    56 
    57         return final_outputs
    58 
    59 # 设置每层节点个数
    60 input_nodes = 784
    61 hidden_nodes = 100
    62 output_nodes = 10
    63 # 设置学习率为0.3
    64 learning_rate = 0.3
    65 # 创建神经网络
    66 n = NeuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)
    67 
    68 #读取训练数据集 转化为列表
    69 training_data_file = open("D:/mnist_train_100.csv",'r')
    70 training_data_list = training_data_file.readlines();
    71 training_data_file.close()
    72 
    73 #训练神经网络
    74 for record in training_data_list:
    75     #根据逗号,将文本数据进行拆分
    76     all_values = record.split(',')
    77     #将文本字符串转化为实数,并创建这些数字的数组。
    78     inputs = (numpy.asfarray(all_values[1:])/255.0 * 0.99) + 0.01
    79     #创建用零填充的数组,数组的长度为output_nodes,加0.01解决了0输入造成的问题
    80     targets = numpy.zeros(output_nodes) + 0.01
    81     #使用目标标签,将正确元素设置为0.99
    82     targets[int(all_values[0])] = 0.99
    83     n.train(inputs,targets)
    84     pass
    85 
    86 #读取测试文件
    87 test_data_file = open("D:/mnist_test_10.csv",'r')
    88 test_data_list = test_data_file.readlines()
    89 test_data_file.close()
    90 
    91 all_values = test_data_list[0].split(',')
    92 print(all_values[0])  #输出目标值
    93 
    94 image_array = numpy.asfarray(all_values[1:]).reshape((28,28))
    95 print(n.query((numpy.asfarray(all_values[1:])/255.0*0.99)+0.01))#输出标签值
    96 plt.imshow(image_array,cmap='Greys',interpolation='None')#显示原图像
    97 pylab.show()

    输出情况:

      从结果可以看出,我们输入的目标值为7,结果中第7个标签所对应的值最大,表明了正确识别了目标值。和图片中的值一样。

  • 相关阅读:
    django虚拟环境中报E: 无法定位软件包 sqliteman
    创建django项目
    Django虚拟环境安装
    python学习笔记(三)
    python学习笔记(二)
    python学习笔记(一)
    python 类属性和实例属性
    决策树的基本ID3算法
    KNN算法的简单实现
    webClient
  • 原文地址:https://www.cnblogs.com/carlber/p/9707271.html
Copyright © 2011-2022 走看看