zoukankan      html  css  js  c++  java
  • BP神经网络在python下的自主搭建梳理

    本实验使用mnist数据集完成手写数字识别的测试。识别正确率认为是95%

    完整代码如下:

      1 #!/usr/bin/env python
      2 # coding: utf-8
      3 
      4 # In[1]:
      5 
      6 
      7 import numpy
      8 import scipy.special
      9 import matplotlib.pyplot
     10 
     11 
     12 # In[2]:
     13 
     14 
     15 class neuralNetwork:
     16     def __init__(self, inputNodes, hiddenNodes, outputNodes,learningRate):
     17         self.iNodes = inputNodes
     18         self.oNodes = outputNodes
     19         self.hNodes = hiddenNodes
     20         self.lr = learningRate
     21         self.wih = numpy.random.normal (0.0, pow(self.hNodes,-0.5), (self.hNodes, self.iNodes))
     22         self.who = numpy.random.normal (0.0, pow(self.oNodes,-0.5), (self.oNodes, self.hNodes))
     23         
     24         self.activation_function = lambda x: scipy.special.expit(x)
     25         #print(self.wih)
     26         pass
     27     
     28     def train(self,inputs_list, target_list):
     29         inputs = numpy.array(inputs_list, ndmin=2).T
     30         targets = numpy.array(target_list, ndmin=2).T
     31         #print(inputs)
     32         #print(targets)
     33         hidden_inputs = numpy.dot(self.wih,inputs)
     34         #print(self.wih.shape)
     35         #print(inputs.shape)
     36         hidden_outputs = self.activation_function(hidden_inputs)
     37         #print(hidden_inputs)
     38         final_inputs = numpy.dot(self.who,hidden_outputs)
     39         #print(hidden_outputs)
     40         final_outputs = self.activation_function(final_inputs)
     41         
     42         output_errors = targets - final_outputs
     43         hidden_errors = numpy.dot(self.who.T,output_errors)
     44         self.who += self.lr * numpy.dot((output_errors * final_outputs * (1.0 - final_outputs)),numpy.transpose(hidden_outputs))
     45         self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),numpy.transpose(inputs))
     46         pass
     47     
     48     def query(self, inputs_list):
     49         inputs = numpy.array(inputs_list, ndmin=2).T
     50         hidden_inputs = numpy.dot(self.wih,inputs)
     51         hidden_outputs = self.activation_function(hidden_inputs)
     52         final_inputs = numpy.dot(self.who,hidden_outputs)
     53         final_outpus = self.activation_function(final_inputs)
     54         return final_outpus
     55         pass
     56     
     57 
     58 
     59 # In[3]:
     60 
     61 
     62 inputNodes = 784
     63 outputNodes = 10
     64 hiddenNodes = 100
     65 learningRate = 0.1
     66 nN = neuralNetwork(inputNodes, hiddenNodes, outputNodes, learningRate)
     67 
     68 
     69 # In[4]:
     70 
     71 
     72 data_file = open("mnist_train.csv",'r')
     73 data_list = data_file.readlines()
     74 data_file.close()
     75 
     76 
     77 # In[5]:
     78 
     79 
     80 epochs = 1
     81 for e in range(epochs) :
     82     for record in data_list:
     83         all_values = record.split(',')
     84         inputs = numpy.asfarray( all_values [1:])/255.0*0.99+0.01
     85         targets = numpy.zeros(outputNodes) + 0.01
     86         targets[int (all_values[0])] = 0.99
     87         nN.train(inputs,targets)
     88         pass
     89     pass
     90 
     91 
     92 # In[6]:
     93 
     94 
     95 test_data_file = open("mnist_test.csv",'r')
     96 test_data_list = test_data_file.readlines()
     97 test_data_file.close()
     98 
     99 
    100 # In[7]:
    101 
    102 
    103 scorecard = []
    104 for record in test_data_list:
    105     all_values = record.split(',')
    106     correct_label = int(all_values[0])
    107     inputs = numpy.asfarray( all_values [1:])/255.0*0.99+0.01
    108     outputs = nN.query(inputs)
    109     label = numpy.argmax(outputs)
    110     if(label == correct_label):
    111         scorecard.append(1)
    112     else:
    113         scorecard.append(0)
    114         pass
    115     pass
    116 
    117 
    118 # In[8]:
    119 
    120 
    121 scorecard_array = numpy.asarray(scorecard)
    122 print ("performance = " ,scorecard_array.sum()/scorecard_array.size)
    123 
    124 
    125 # In[9]:
    126 
    127 
    128 import scipy.misc
    129 img_array = scipy.misc.imread('test.png',flatten="True")
    130 img_data = 255.0 - img_array . reshape(784)
    131 img_data = (img_data /255.0 * 0.99 ) + 0.01
    132 op=nN.query(img_data)
    133 print(op)
    134 print(numpy.argmax(op))
    135 
    136 
    137 # In[10]:
    138 
    139 
    140 all_values = data_list[1].split(',')
    141 image_array = numpy.asfarray( all_values [1:]).reshape((28,28))
    142 matplotlib.pyplot.imshow(image_array, cmap = 'Greys',interpolation='None')
    View Code

    IN[9]到IN[10]的代码分别用于测试自己制作的数字识别效果和显示图像。可去掉。

    代码运行过程需要mnist数据集,链接:https://pan.baidu.com/s/120GTdZ8Tivkp1KD9VQ_XeQ

     BP神经网络的结构:https://www.cnblogs.com/bai2018/p/10353768.html

    在输入层的神经元数据选取上,和像素数量一致。MNIST采用28X28的像素点,则输入层的神经元数量为28*28=784个

    输入层和隐层,输出层和隐层之间的权值选取为随机数。使用正态分布的随机数较好。

    隐层的神经元数量合适即可,取值为经验法,假设为100个

    输出层神经元表示数据0-9,则使用10个神经元,分别表示数字0-9的可能性概率。

    训练过程中使用的学习效率,取0.2吧。。。

    将权重,各层神经元值,误差等,表示为矩阵数据进行处理。

    正向传递数据查询结果,误差的反向传递改变权重等过程,涉及到的数学推导:https://www.cnblogs.com/bai2018/p/10353814.html

  • 相关阅读:
    Spring Boot任务管理之定时任务
    Spring Boot任务管理之有返回值异步任务调用
    Spring Boot任务管理之无返回值异步任务调用
    Spring Boot整合Thymeleaf
    IDEA配置maven详细教程
    虚拟机运行centos设置固定IP
    django.db.utils.IntegrityError: (1048, "Column 'spu_id' cannot be null")关于RESTframework使用序列化器报错问题
    RuntimeWarning: DateTimeField User.date_joined received a naive datetime (2020-08-01 00:00:00) while time zone support is active. warnings.warn("DateTimeField %s.%s received a naive datetime "问题
    ModuleNotFoundError: No module named 'PIL'问题
    ImportError: cannot import name 'Feature' from 'setuptools' (D:python_learnmeiduo_projectenvlibsite-packagessetuptools\__init__.py)问题
  • 原文地址:https://www.cnblogs.com/bai2018/p/10353747.html
Copyright © 2011-2022 走看看