zoukankan      html  css  js  c++  java
  • TensorFlow下基于MLP网络识别手写数字

      1 # coding: utf-8
      2 
      3 # In[1]:
      4 
      5 import tensorflow as tf
      6 import tensorflow.examples.tutorials.mnist.input_data as input_data
      7 
      8 
      9 # In[2]:
     10 
     11 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
     12 
     13 
     14 # In[3]:
     15 
     16 print('train',mnist.train.num_examples,
     17       ',validation',mnist.validation.num_examples,
     18       ',test',mnist.test.num_examples)
     19 
     20 
     21 # In[4]:
     22 
     23 print('train images     :', mnist.train.images.shape,
     24       'labels:'           , mnist.train.labels.shape)
     25 
     26 
     27 # In[5]:
     28 
     29 import matplotlib.pyplot as plt
     30 def plot_image(image):
     31     plt.imshow(image.reshape(28,28),cmap='binary')
     32     plt.gcf().set_size_inches(2, 4)
     33     plt.show()
     34 
     35 
     36 # In[6]:
     37 
     38 plot_image(mnist.train.images[0])
     39 
     40 
     41 # In[7]:
     42 
     43 import numpy as np
     44 np.argmax(mnist.train.labels[0])
     45 
     46 
     47 # In[8]:
     48 
     49 import matplotlib.pyplot as plt
     50 def plot_images_labels_prediction(images,labels,
     51                                   prediction,idx,num=10):
     52     fig = plt.gcf()
     53     fig.set_size_inches(12, 14)
     54     if num>25: num=25 
     55     for i in range(0, num):
     56         ax=plt.subplot(5,5, 1+i)
     57         
     58         ax.imshow(np.reshape(images[idx],(28, 28)), 
     59                   cmap='binary')
     60             
     61         title= "label=" +str(np.argmax(labels[idx]))
     62         if len(prediction)>0:
     63             title+=",predict="+str(prediction[idx]) 
     64             
     65         ax.set_title(title,fontsize=10) 
     66         ax.set_xticks([]);ax.set_yticks([])        
     67         idx+=1 
     68     plt.show()
     69 
     70 
     71 # In[9]:
     72 
     73 plot_images_labels_prediction(mnist.train.images,
     74                               mnist.train.labels,[],0)
     75 
     76 
     77 # In[10]:
     78 
     79 def layer(output_dim,input_dim,inputs, activation=None):
     80     W = tf.Variable(tf.random_normal([input_dim, output_dim]))
     81     b = tf.Variable(tf.random_normal([1, output_dim]))
     82     XWb = tf.matmul(inputs, W) + b
     83     if activation is None:
     84         outputs = XWb
     85     else:
     86         outputs = activation(XWb)
     87     return outputs
     88 
     89 
     90 # In[11]:
     91 
     92 x = tf.placeholder("float", [None, 784])
     93 h1=layer(output_dim=256,input_dim=784,
     94          inputs=x ,activation=tf.nn.relu)  
     95 y_predict=layer(output_dim=10,input_dim=256,
     96                     inputs=h1,activation=None)
     97 y_label = tf.placeholder("float", [None, 10])
     98 
     99 
    100 # In[12]:
    101 
    102 loss_function = tf.reduce_mean(
    103                   tf.nn.softmax_cross_entropy_with_logits
    104                          (logits=y_predict , 
    105                           labels=y_label))
    106 optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function)
    107 
    108 
    109 # In[13]:
    110 
    111 correct_prediction = tf.equal(tf.argmax(y_label  , 1),
    112                               tf.argmax(y_predict, 1))
    113 accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    114 
    115 
    116 # In[14]:
    117 
    118 trainEpochs = 20
    119 batchSize = 100
    120 totalBatchs = int(mnist.train.num_examples/batchSize)
    121 epoch_list=[];loss_list=[];accuracy_list=[]
    122 from time import time
    123 startTime=time()
    124 sess = tf.Session()
    125 sess.run(tf.global_variables_initializer())
    126 
    127 
    128 # In[15]:
    129 
    130 for epoch in range(trainEpochs):
    131     for i in range(totalBatchs):
    132         batch_x, batch_y = mnist.train.next_batch(batchSize)
    133         sess.run(optimizer,feed_dict={x: batch_x,y_label: batch_y})
    134         
    135     loss,acc = sess.run([loss_function,accuracy],
    136                         feed_dict={x: mnist.validation.images, 
    137                                    y_label: mnist.validation.labels})
    138 
    139     epoch_list.append(epoch);
    140     loss_list.append(loss)
    141     accuracy_list.append(acc)    
    142     print("Train Epoch:", '%02d' % (epoch+1), "Loss=",                 "{:.9f}".format(loss)," Accuracy=",acc)
    143     
    144 duration =time()-startTime
    145 print("Train Finished takes:",duration)      
    146 
    147 
    148 # In[16]:
    149 
    150 get_ipython().magic('matplotlib inline')
    151 import matplotlib.pyplot as plt
    152 fig = plt.gcf()
    153 fig.set_size_inches(4,2)
    154 plt.plot(epoch_list, loss_list, label = 'loss')
    155 plt.ylabel('loss')
    156 plt.xlabel('epoch')
    157 plt.legend(['loss'], loc='upper left')
    158 
    159 
    160 # In[17]:
    161 
    162 plt.plot(epoch_list, accuracy_list,label="accuracy" )
    163 fig = plt.gcf()
    164 fig.set_size_inches(4,2)
    165 plt.ylim(0.8,1)
    166 plt.ylabel('accuracy')
    167 plt.xlabel('epoch')
    168 plt.legend()
    169 plt.show()
    170 
    171 
    172 # In[18]:
    173 
    174 print("Accuracy:", sess.run(accuracy,
    175                            feed_dict={x: mnist.test.images,
    176                                       y_label: mnist.test.labels}))
    177 
    178 
    179 # In[19]:
    180 
    181 prediction_result=sess.run(tf.argmax(y_predict,1),
    182                            feed_dict={x: mnist.test.images })
    183 prediction_result[:10]
    184 
    185 
    186 # In[20]:
    187 
    188 plot_images_labels_prediction(mnist.test.images,
    189                               mnist.test.labels,
    190                               prediction_result,0)
    191 
    192 
    193 # In[21]:
    194 
    195 y_predict_Onehot=sess.run(y_predict,
    196                           feed_dict={x: mnist.test.images })
    197 y_predict_Onehot[8]
    198 
    199 
    200 # In[22]:
    201 
    202 for i in range(400):
    203     if prediction_result[i]!=np.argmax(mnist.test.labels[i]):
    204         print("i="+str(i)+
    205               "   label=",np.argmax(mnist.test.labels[i]),
    206               "predict=",prediction_result[i])
    207 
    208 
    209 # In[ ]:
    View Code

    代码如上。

    手动建立输入层、隐藏层和输出层。

    设置损失函数和优化器：

    评估方式与准确率:

    开始分批次训练:

    训练完成后的准确率:

     查看某一项的预测输出（输出层没有激活函数，这里得到的是 softmax 之前的 logits，而不是概率；需要概率时可再做一次 softmax）：

    筛选出预测失败的数据:

     可以通过以下代码将计算图写入 log/area 目录：

    tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('log/area', sess.graph)
    train_writer.close()  # 关闭时会把事件文件刷新到磁盘，否则 TensorBoard 可能暂时读不到

    这样就保存了计算图。

    运行 tensorboard --logdir="路径" 启动服务，然后在浏览器中访问 localhost:6006（TensorBoard 的默认端口）即可打开页面。

    查看生成的图:

  • 相关阅读:
    MongoDB 学习笔记(七):主从复制与副本集
    MongoDB 学习笔记(六):备份与用户管理
    MongoDB 学习笔记(五):固定集合、GridFS文件系统与服务器端脚本
    MongoDB 学习笔记(四):索引
    MongoDB 学习笔记(三):分页、排序与游标
    MongoDB 学习笔记(一):安装及简单shell操作
    MongoDB 学习笔记(二):shell中执行增删查改
    mongoDB 入门指南、示例
    mongoDB 介绍(特点、优点、原理)
    企业级任务调度框架Quartz(7) 线程在Quartz里的意义(1)
  • 原文地址:https://www.cnblogs.com/bai2018/p/10472000.html
Copyright © 2011-2022 走看看