zoukankan      html  css  js  c++  java
  • 神经网络简易的框架

    最近学习了神经网络的原理,以下是学习的成果,分享出来

      1 import numpy
      2 #scipy.special for the sigmoid function expit()
      3 import scipy.special #专用扩展函数
      4 #libray for plotting arrays
      5 import matplotlib.pyplot  #用于绘制数组的libray
      6 #ensure the plots are inside this notebook,not anexternal windows  #在页面内呈现
      7 %matplotlib inline
      8 
      9 
     10 #神经网络框架类
     11 class neuralNetwork:
     12     #初始化神经网络    输入     隐藏层     输出   学习率
     13     def  __init__(self,inputnodes,hiddennondes,outputnodes,learningrate):
     14         self.inodes=inputnodes
     15         self.hondes=hiddennondes
     16         self.onodes=outputnodes
     17        
     18         #链接权值矩阵 link weight matrices ,wih and who
     19         #weights inside the arrays are w_i_j,where link is from node in to node j in the next layer
     20         #数组内的权重是w_i_j,其中link是从节点in到下一层的节点j
     21         #self.wih=(numpy.random.rand(self.hondes,self.inodes)-0.5)
     22         #self.who=(numpy.random.rand(self.onodes,self.hondes)-0.5)
     23         
     24         #初始化权重
     25         self.wih = numpy.random.normal(0.0,pow(self.hondes,-0.5),(self.hondes,self.inodes))
     26         self.who = numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.onodes,self.hondes))
     27         
     28         #学习率初始
     29         self.lr=learningrate
     30         
     31         #激活函数是s型函数
     32         self.activation_function = lambda x: scipy.special.expit(x)
     33         
     34         pass
     35     
     36     #训练神经网络  train the neural network
     37     def train(self,inputs_list,targets_list):
     38         #将输入列表转换为二维数组
     39         inputs =numpy.array(inputs_list,ndmin=2).T
     40         targets = numpy.array(targets_list,ndmin=2).T
     41         
     42         #计算信号进入隐含层
     43         hidden_inputs = numpy.dot(self.wih,inputs)
     44         #计算从隐含层出现的信号
     45         hidden_outputs = self.activation_function(hidden_inputs)
     46         
     47         #计算信号进入输出层
     48         final_inputs = numpy.dot(self.who,hidden_outputs)
     49         #计算从输出层出现的信号
     50         final_outputs =self.activation_function(final_inputs)
     51         
     52         #误差计算
     53         output_errors = targets-final_outputs
     54         #隐含层误差馈
     55         hidden_errors =numpy.dot(self.who.T,output_errors)
     56         
     57         #隐含层权重更新
     58         self.who += self.lr*numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),numpy.transpose(hidden_outputs))
     59         
     60         #输入层权重更新
     61         self.wih += self.lr*numpy.dot((hidden_errors*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs))
     62         
     63         
     64         
     65         pass
     66     #查询
     67     def query(self,inputs_list):
     68         #将输入列表转换为二维数组
     69         inputs = numpy.array(inputs_list,ndmin=2).T
     70         
     71         #计算信号到隐藏层
     72         hidden_inputs = numpy.dot(self.wih,inputs)
     73         #计算从隐含层出现的信号
     74         hidden_outputs =self.activation_function(hidden_inputs)
     75         
     76         #计算信号到最终的输出层
     77         final_inputs = numpy.dot(self.who,hidden_outputs)
     78         #计算从最终输出层出现的信号
     79         final_outputs = self.activation_function(final_inputs)
     80         
     81         return final_outputs
     82         
     83 #数字输入,隐藏层、输出层
     84 input_nodes= 784
     85 hidden_nodes=500
     86 output_nodes=10
     87 
     88 
     89 #学习率设置  learning rate
     90 learning_rate = 0.1
     91 
     92 
     93 #创建神经网络实例
     94 n=neuralNetwork(input_nodes,hidden_nodes,output_nodes,learning_rate)
     95 
     96 
     97 #打开训练集
     98 training_data_file = open("mnist_dataset/mnist_train.csv","r")
     99 training_data_list = training_data_file.readlines()
    100 training_data_file.close()
    101 
    102 
    103 #训练神经网络
    104 
    105 
    106 #网络执行情况的记分卡最初为空
    107 #scorecard =[]
    108 
    109 
    110 #检查训练数据集中的所有记录
    111 def traIning():
    112     for record in training_data_list:
    113         all_values = record.split(',')#用,分割为数组
    114         inputs  = (numpy.asfarray(all_values[1:])/255.0*0.99)+0.01  #整体偏移到0.01到0.99之间  规模和转移输入
    115         targets =numpy.zeros(output_nodes)+0.01  #创建目标输出值 在0.01-0.99之间
    116         targets[int(all_values[0])] =0.99  #是该记录的目标标签
    117         n.train(inputs,targets)
    118         pass
    119 
    120 
    121 #epochs是训练数据集用于训练的次数
    122 epochs = 5
    123 for e in range(epochs):
    124     traIning()
    125     
    126     
    127 print("训练完成!")
    128 
    129 
    130 
    131 
    132 test_data_file  = open("mnist_dataset/mnist_test.csv",'r')#打开测试文件
    133 test_data_list =test_data_file.readlines() #读取数据
    134 test_data_file.close()#完毕文件
    135 
    136 
    137 #网络执行情况的记分卡最初为空
    138 scorecard =[]
    139 for record in test_data_list:
    140     all_values =record.split(',')
    141     correct_label = int(all_values[0]) #训练集中的数字
    142     inputs =(numpy.asfarray(all_values[1:])/25.0*0.99)+0.01  #权重#整体偏移到0.01到0.99之间
    143     outputs =n.query(inputs)#进行查询
    144     label = numpy.argmax(outputs)  #数字 标签
    145     #print(label,"network's answer")
    146     if (label == correct_label):  #
    147         scorecard.append(1)
    148     else:
    149         scorecard.append(0)
    150         pass
    151 all_value =test_data_list[5].split(',')  #查看对应的数字
    152 print(all_value[0])
    153 
    154 
    155 scorecard_array = numpy.asarray(scorecard)
    156 print("performance=",scorecard_array.sum()/scorecard_array.size) #计算准确率
    MNIST数据库提供的手写数字,训练集与测试集地址
  • 相关阅读:
    【SHOI2002】百事世界杯之旅
    【LGOJ 3384】树链剖分
    [20191006机房测试] 括号序列
    [20191006机房测试] 矿石
    【SHOI2012】回家的路
    [20191005机房测试] Seat
    [20191005机房测试] Silhouette
    每年六一儿童节,牛客都会准备一些小礼物去看望孤儿院的小朋友,今年亦是如此。HF作为牛客的资深元老,自然也准备了一些小游戏。其中,有个游戏是这样的:首先,让小朋友们围成一个大圈。然后,他随机指定一个数m,让编号为0的小朋友开始报数。每次喊到m-1的那个小朋友要出列唱首歌,然后可以在礼品箱中任意的挑选礼物,并且不再回到圈中,从他的下一个小朋友开始,继续0...m-1报数....这样
    fgets函数读取最后一行的时候为什么会重复
    c语言中返回的变量地址,其物理地址在?(刨根问底)
  • 原文地址:https://www.cnblogs.com/uge3/p/13305290.html
Copyright © 2011-2022 走看看