zoukankan      html  css  js  c++  java
  • 手写神经网络Python深度学习

    import numpy
    import scipy.special
    import matplotlib.pyplot as plt
    import scipy.misc
    import glob
    import imageio
    import scipy.ndimage
    
    class neuralNetWork:
        """Three-layer (input / hidden / output) feed-forward network.

        Both weight matrices are drawn from a normal distribution centred on
        zero whose standard deviation is ``nodes ** -0.5`` of the layer feeding
        the link.  The logistic sigmoid (``scipy.special.expit``) is used as the
        activation function and ``scipy.special.logit`` as its inverse.
        """

        def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
            """Store the layer sizes and learning rate and draw initial weights."""
            self.inodes = inputnodes
            self.hnodes = hiddennodes
            self.onodes = outputnodes

            # wih: input -> hidden links, who: hidden -> output links.
            # The two matrices are drawn in this order from the global RNG.
            self.wih = numpy.random.normal(
                0.0, self.inodes ** -0.5, (self.hnodes, self.inodes)
            )
            self.who = numpy.random.normal(
                0.0, self.hnodes ** -0.5, (self.onodes, self.hnodes)
            )

            self.lr = learningrate

            # Activation function (sigmoid) and its inverse, used by backquery.
            self.activation_function = lambda x: scipy.special.expit(x)
            self.inverse_activation_function = lambda x: scipy.special.logit(x)

        def _feed_forward(self, inputs):
            """Return ``(hidden_outputs, final_outputs)`` for 2-D ``inputs``."""
            hidden_outputs = self.activation_function(numpy.dot(self.wih, inputs))
            final_outputs = self.activation_function(
                numpy.dot(self.who, hidden_outputs)
            )
            return hidden_outputs, final_outputs

        @staticmethod
        def _rescale(signal):
            """Shift and scale ``signal`` in place into the range 0.01..0.99."""
            signal -= numpy.min(signal)
            signal /= numpy.max(signal)
            signal *= 0.98
            signal += 0.01
            return signal

        def train(self, inputs_list, targets_list):
            """Run one forward pass and update both weight matrices in place.

            ``inputs_list`` and ``targets_list`` are converted to 2-D arrays and
            transposed before use.
            """
            inputs = numpy.array(inputs_list, ndmin=2).T
            targets = numpy.array(targets_list, ndmin=2).T

            hidden_outputs, final_outputs = self._feed_forward(inputs)

            # The hidden-layer error is derived from the hidden->output weights
            # *before* those weights are adjusted below.
            output_errors = targets - final_outputs
            hidden_errors = numpy.dot(self.who.T, output_errors)

            output_delta = output_errors * final_outputs * (1.0 - final_outputs)
            hidden_delta = hidden_errors * hidden_outputs * (1.0 - hidden_outputs)

            self.who += self.lr * numpy.dot(output_delta, hidden_outputs.T)
            self.wih += self.lr * numpy.dot(hidden_delta, inputs.T)

        def query(self, inputs_list):
            """Return the output-layer activations for ``inputs_list``."""
            inputs = numpy.array(inputs_list, ndmin=2).T
            _, final_outputs = self._feed_forward(inputs)
            return final_outputs

        def backquery(self, targets_list):
            """Propagate output activations backwards to an input-layer signal.

            Each layer's signal is rescaled into 0.01..0.99 so that the inverse
            activation (logit) stays finite on the next step back.
            """
            final_outputs = numpy.array(targets_list, ndmin=2).T

            # Output layer back to the hidden layer.
            final_inputs = self.inverse_activation_function(final_outputs)
            hidden_outputs = self._rescale(numpy.dot(self.who.T, final_inputs))

            # Hidden layer back to the input layer.
            hidden_inputs = self.inverse_activation_function(hidden_outputs)
            return self._rescale(numpy.dot(self.wih.T, hidden_inputs))
      
    
    # Network dimensions and learning rate, then the network itself.
    # NOTE: "learing_rate" (sic) keeps its original spelling because it is a
    # module-level name that other code may reference.
    input_nodes = 784
    hidden_nodes = 200
    output_nodes = 10
    learing_rate = 0.1
    n = neuralNetWork(input_nodes,hidden_nodes,output_nodes,learing_rate)

    # Read every training record (one CSV line each) into memory.  The
    # context manager closes the file even if reading raises.
    with open('mnist_train.csv', 'r') as train_data_file:
        train_data_list = train_data_file.readlines()
    
    # Number of full passes over the training records.
    epochs = 5
    for e in range(epochs):
        for record in train_data_list:
            # Field 0 is the label, the remaining fields are pixel values.
            all_values = record.split(',')

            # Map pixel values from 0..255 onto 0.01..1.00.
            # numpy.asfarray was removed in NumPy 2.0; asarray with a float
            # dtype is the documented equivalent.
            inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01

            # Target vector: 0.01 everywhere except 0.99 at the label index.
            targets = numpy.zeros(output_nodes) + 0.01
            targets[int(all_values[0])] = 0.99
            n.train(inputs, targets)

            # Data augmentation: also train on the same image rotated by
            # +10 and -10 degrees (these are extra training samples, not
            # test data).  scipy.ndimage.rotate is the public entry point;
            # the scipy.ndimage.interpolation namespace is deprecated.
            inputs_plusx_img = scipy.ndimage.rotate(
                inputs.reshape(28, 28), 10, cval=0.01, order=1, reshape=False
            )
            n.train(inputs_plusx_img.reshape(784), targets)
            inputs_minusx_img = scipy.ndimage.rotate(
                inputs.reshape(28, 28), -10, cval=0.01, order=1, reshape=False
            )
            n.train(inputs_minusx_img.reshape(784), targets)
    
    
    # Read every test record (one CSV line each) into memory.  The context
    # manager closes the file even if reading raises.
    with open('mnist_test.csv', 'r') as test_data_file:
        test_data_list = test_data_file.readlines()
    # all_values = test_data_list[0].split(',')
    
    # # image_array = numpy.asfarray(all_values[1:]).reshape((28,28))
    # # plt.imshow(image_array,cmap='Greys',interpolation='None')
    # # plt.show()
    
    # output = n.query((numpy.asfarray(all_values[1:])/ 255.0 * 0.99)+0.01)
    
    
    # Score the network on the test records: 1 for a correct answer, 0 otherwise.
    scorecard = []
    for record in test_data_list:
        # Field 0 is the expected label, the remaining fields are pixel values.
        all_values = record.split(',')
        correct_label = int(all_values[0])

        # Same 0..255 -> 0.01..1.00 scaling that is applied to training data.
        # numpy.asfarray was removed in NumPy 2.0; asarray with a float dtype
        # is the documented equivalent.
        inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01
        outputs = n.query(inputs)

        # The index of the largest output is the network's answer.
        label = numpy.argmax(outputs)
        scorecard.append(1 if label == correct_label else 0)

    # Fraction of test records answered correctly.
    scorecard_array = numpy.asarray(scorecard)
    print("performance = ",scorecard_array.sum() / scorecard_array.size)
    
    # Recognise our own handwritten digits: build a list of records, each
    # holding the label followed by the 784 scaled pixel values.
    our_own_dataset = []

    for image_file_name in glob.glob('2828_my_own_?.png'):
        # The single character matched by "?" (just before ".png") is the label.
        label = int(image_file_name[-5:-4])

        print ("loading ... ", image_file_name)
        img_array = imageio.imread(image_file_name, as_gray=True)

        # Flatten and invert the intensities (255 - value), then apply the
        # 0.01..1.00 scaling.
        # NOTE(review): the inversion presumably matches the MNIST polarity - confirm.
        img_data = 255.0 - img_array.reshape(784)
        img_data = (img_data / 255.0 * 0.99) + 0.01

        # Report the value range as a quick sanity check.
        print(img_data.min())
        print(img_data.max())

        record = numpy.append(label, img_data)
        our_own_dataset.append(record)
    
    # Pick one record from our own dataset and split it into label and pixels.
    item = 2
    correct_label = our_own_dataset[item][0]
    inputs = our_own_dataset[item][1:]

    # Show the selected image.
    plt.imshow(inputs.reshape(28, 28), cmap='Greys', interpolation='None')

    # Ask the network and print the raw output activations.
    outputs = n.query(inputs)
    print (outputs)

    # The index of the largest output is the network's answer.
    label = numpy.argmax(outputs)
    print("network says ", label)
    print ("match!" if label == correct_label else "no match!")
    
    # Generate an image by running the network backwards from a chosen label.
    label = 0

    # Target vector: 0.01 everywhere except 0.99 at the chosen label.
    targets = numpy.full(output_nodes, 0.01)
    targets[label] = 0.99
    print(targets)

    image_data = n.backquery(targets)

    # Display the resulting input-layer signal as a 28x28 image.
    plt.imshow(image_data.reshape(28, 28), cmap='Greys', interpolation='None')
  • 相关阅读:
    Codeforces Round #251 (Div. 2) A
    topcoder SRM 623 DIV2 CatAndRat
    topcoder SRM 623 DIV2 CatchTheBeatEasy
    topcoder SRM 622 DIV2 FibonacciDiv2
    topcoder SRM 622 DIV2 BoxesDiv2
    Leetcode Linked List Cycle II
    leetcode Linked List Cycle
    Leetcode Search Insert Position
    关于vim插件
    Codeforces Round #248 (Div. 2) B. Kuriyama Mirai's Stones
  • 原文地址:https://www.cnblogs.com/Erick-L/p/11785905.html
Copyright © 2011-2022 走看看