zoukankan      html  css  js  c++  java
  • 深度学习(一):Python神经网络——手写数字识别

    声明:本文是阅读《Python神经网络编程》一书后整理的学习笔记,代码与书中略有差异。书籍封面:

    源码

    若要在本地运行,请将源码中图片与数据集的路径改为本机上的实际位置;运行环境为 Python 3.6.x。

      1 import numpy as np
      2 import scipy.special as ss
      3 import matplotlib.pyplot as plt
      4 import imageio as im
      5 import glob as gl
      6 
      7 
      8 class NeuralNetwork:
      9     # initialise the network
     10     def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
     11         # set number of each layer
     12         self.inodes = inputnodes
     13         self.hnodes = hiddennodes
     14         self.onodes = outputnodes
     15         self.wih = np.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))
     16         self.who = np.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))
     17         # learning rate
     18         self.lr = learningrate
     19         # activation function is sigmoid
     20         self.activation_function = lambda x: ss.expit(x)
     21         pass
     22 
     23     # train the neural network
     24     def train(self, inputs_list, targets_list):
     25         inputs = np.array(inputs_list, ndmin=2).T
     26         targets = np.array(targets_list, ndmin=2).T
     27         hidden_inputs = np.dot(self.wih, inputs)
     28         hidden_outputs = self.activation_function(hidden_inputs)
     29         final_inputs = np.dot(self.who, hidden_outputs)
     30         final_outputs = self.activation_function(final_inputs)
     31         # errors
     32         output_errors = targets - final_outputs
     33         # b-p algorithm
     34         hidden_errors = np.dot(self.who.T, output_errors)
     35         # update weight
     36         self.who += self.lr * np.dot((output_errors * final_outputs * (1.0 - final_outputs)),
     37                                      np.transpose(hidden_outputs))
     38         self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)), np.transpose(inputs))
     39         pass
     40 
     41     # query the neural network
     42     def query(self, inputs_list):
     43         inputs = np.array(inputs_list, ndmin=2).T
     44         hidden_inputs = np.dot(self.wih, inputs)
     45         hidden_outputs = self.activation_function(hidden_inputs)
     46         final_inputs = np.dot(self.who, hidden_outputs)
     47         final_outputs = self.activation_function(final_inputs)
     48         return final_outputs
     49 
     50     # numbers
     51 
     52 
     53 input_nodes = 784
     54 hidden_nodes = 100
     55 output_nodes = 10
     56 
     57 # learning rate
     58 learning_rate = 0.2
     59 
     60 # creat instance of neural network
     61 global n
     62 n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)
     63 
     64 # file read only ,root of the file
     65 training_data_file = open(r"C:UsersELIODesktopmnist_train.txt", 'r')
     66 training_data_list = training_data_file.readlines()
     67 training_data_file.close()
     68 
     69 # train the neural network
     70 epochs = 5
     71 for e in range(epochs):
     72     for record in training_data_list:
     73         all_values = record.split(',')
     74         # scale and shift the inputs
     75         inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01
     76         targets = np.zeros(output_nodes) + 0.01
     77         # all_values[0] is the target label for this record
     78         targets[int(all_values[0])] = 0.99
     79         n.train(inputs, targets)
     80     pass
     81 pass
     82 
     83 # load the file into a list
     84 test_data_file = open(r"C:UsersELIODesktopmnist_train_100.csv.txt", 'r')
     85 test_data_list = test_data_file.readlines()
     86 test_data_file.close()
     87 
     88 # test the neural network
     89 # score for how well the network performs
     90 score = []
     91 
     92 # go through all the records
     93 for record in test_data_list:
     94     all_values = record.split(',')
     95     # correct answer is the first value
     96     correct_label = int(all_values[0])
     97     # scale and shift the inputs
     98     inputs = (np.asfarray(all_values[1:]) / 255.0 * 0.99) + 0.01
     99     # query the network
    100     outputs = n.query(inputs)
    101     # the index of the highest value corresponds to the label
    102     label = np.argmax(outputs)
    103     # append correct or incorrect to list
    104     if (label == correct_label):
    105         score.append(1)
    106     else:
    107         score.append(0)
    108         pass
    109 pass
    110 # module1 CORRECT-RATE
    111 # calculate the score, the fraction of correct answers
    112 score_array = np.asarray(score)
    113 print("performance = ", score_array.sum() / score_array.size)
    114 
    115 # module2 TEST MNIST
    116 all_values = test_data_list[0].split(',')
    117 print(all_values[0])
    118 image_array = np.asfarray(all_values[1:]).reshape((28, 28))
    119 plt.imshow(image_array, cmap='Greys', interpolation='None')
    120 plt.show()
    121 
    122 # module3 USE YOUR WRITING
    123 # own image test data set
    124 own_dataset = []
    125 for image_file_name in gl.gl(r'C:UsersELIODesktop5.png'):
    126     print("loading ... ", image_file_name)
    127     # use the filename to set the label
    128     label = int(image_file_name[-5:-4])
    129     # load image data from png files into an array
    130     img_array = im.imread(image_file_name, as_gray=True)
    131     # reshape from 28x28 to list of 784 values, invert values
    132     img_data = 255.0 - img_array.reshape(784)
    133     # then scale data to range from 0.01 to 1.0
    134     img_data = (img_data / 255.0 * 0.99) + 0.01
    135     print(np.min(img_data))
    136     print(np.max(img_data))
    137     # append label and image data  to test data set
    138     record = np.append(label, img_data)
    139     print(record)
    140     own_dataset.append(record)
    141     pass
    142 all_values = own_dataset[0]
    143 print(all_values[0])

     数据集,实验图片

    链接:百度网盘 
    提取码:1vbq

  • 相关阅读:
    获取本机外网ip和内网ip
    服务器发布MVC常见问题解决方案
    Ext.Net学习笔记01:在ASP.NET WebForm中使用Ext.Net
    Form验证(转)
    各浏览器各版本User-agent汇总 欢迎补充
    MSSQL中把表中的数据导出成Insert
    发布mvc3的项目时system.web.mvc 版本 为3.0.0.1高于服务器版本3.0.0.0 升级到3.0.0.1
    MySQL Packets larger than max_allowed_packet are not allowed
    SQL查看数据库所用用户表数量和使用的空间
    公用提示对话框
  • 原文地址:https://www.cnblogs.com/oasisyang/p/13199480.html
Copyright © 2011-2022 走看看