zoukankan      html  css  js  c++  java
  • TensorFlow—多层感知器—MNIST手写数字识别


    1
    import tensorflow as tf 2 import tensorflow.examples.tutorials.mnist.input_data as input_data 3 import matplotlib.pyplot as plt 4 import numpy as np 5 mnist=input_data.read_data_sets("MNIST_data/",one_hot=True) #下载据数 6 print('train images:',mnist.train.images.shape, #查看数据 7 'labels:',mnist.train.labels.shape) 8 print('validation images:',mnist.validation.images.shape, 9 'labels:',mnist.validation.labels.shape) 10 print('test images:',mnist.test.images.shape, 11 'labels:',mnist.test.labels.shape 12 #定义显示多项图像的函数 13 def plot_images_labels_prediction_3(images,labels,prediction,idx,num=10): 14 fig=plt.gcf() 15 fig.set_size_inches(12,14) 16 if num>25:num=25 17 for i in range(0,num): 18 ax=plt.subplot(5,5,i+1) 19 ax.imshow(np.reshape(images[idx],(28,28)),cmap='binary') 20 title='lable='+str(np.argmax(labels[idx])) 21 if len(prediction)>0: 22 title+=",prediction="+str(prediction[idx]) 23 ax.set_title(title,fontsize=10) 24 ax.set_xticks([]);ax.set_yticks([]) 25 idx+=1 26 plt.show() 27 28 plot_images_labels_prediction_3(mnist.train.images,mnist.train.labels,[],0) 29 #定义layer函数,构建多层感知器模型 30 def layer(output_dim,input_dim,inputs,activation=None): 31 W=tf.Variable(tf.random_normal([input_dim,output_dim])) 32 b=tf.Variable(tf.random_normal([1,output_dim])) 33 XWb=tf.matmul(inputs,W)+b 34 if activation is None: 35 outputs=XWb 36 else: 37 outputs=activation(XWb) 38 return outputs 39 #建立输入层 40 x=tf.placeholder("float",[None,784]) 41 #建立隐藏层 42 h1=layer(output_dim=256,input_dim=784,inputs=x, 43 activation=tf.nn.relu) 44 #建立输出层 45 y_predict=layer(output_dim=10,input_dim=256,inputs=h1, 46 activation=None) 47 y_label=tf.placeholder("float",[None,10]) 48 #定义损失函数 49 loss_function=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits 50 (logits=y_predict, 51 labels=y_label)) 52 #定义优化器 53 optimizer=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function) 54 #计算每一项数据是否预测正确 55 correct_prediction=tf.equal(tf.argmax(y_label,1), 56 tf.argmax(y_predict,1)) 57 #计算预测正确结果的平均值 58 accuracy=tf.reduce_mean(tf.cast(correct_prediction,"float")) 59 #1、定义训练参数 60 trainEpochs=15 #设置执行15个训练周期 61 batchSize=100 #每一批次项数为100 62 totalBatchs=int(mnist.train.num_examples/batchSize) #计算每个训练周期 63 loss_list=[];epoch_list=[];accuracy_list=[] #初始化训练周期、误差、准确率 64 from time import time #导入时间模块 65 startTime=time() #开始计算时间 66 sess=tf.Session() #建立Session 67 sess.run(tf.global_variables_initializer()) #初始化TensorFlow global 变量 68 #2、进行训练 69 for epoch in range(trainEpochs): 70 for i in range(totalBatchs): 71 batch_x,batch_y=mnist.train.next_batch(batchSize) #使用mnist.train.next_batch方法读取批次数据,传入参数batchSize是100 72 sess.run(optimizer,feed_dict={x:batch_x, 73 y_label:batch_y}) #执行批次训练 74 loss,acc=sess.run([loss_function,accuracy], #使用验证数据计算准确率 75 feed_dict={x:mnist.validation.images, 76 y_label:mnist.validation.labels}) 77 epoch_list.append(epoch); #加入训练周期列表 78 loss_list.append(loss) #加入误差列表 79 accuracy_list.append(acc) #加入准确率列表 80 print("Train Epoch:",'%02d' % (epoch+1),"Loss=", 81 "{:.9f}".format(loss),"Accuracy=",acc) 82 duration=time()-startTime 83 print("Train Finished takes:",duration) #计算并显示全部训练所需时间 84 #画出误差执行结果 85 86 fig=plt.gcf() 87 fig.set_size_inches(4,2) 88 plt.plot(epoch_list,loss_list,label='loss') 89 plt.ylabel('loss') 90 plt.xlabel('epoch') 91 plt.legend(['loss'],loc='upper left') 92 #画出准确率执行结果 93 plt.plot(epoch_list,accuracy_list,label="accuracy") 94 fig=plt.gcf() 95 fig.set_size_inches(4,2) 96 plt.ylim(0.8,1) 97 plt.ylabel('accuracy') 98 plt.xlabel('epoch') 99 plt.legend() 100 plt.show() 101 #评估模型准确率 102 print("accuracy:",sess.run(accuracy, 103 feed_dict={x:mnist.test.images, 104 y_label:mnist.test.labels})) 105 #进行预测 106 #1.执行预测 107 prediction_result=sess.run(tf.argmax(y_predict,1), 108 feed_dict={x:mnist.test.images}) 109 #2.预测结果 110 print(prediction_result[:10]) 111 #3.显示前10项预测结果 112 plot_images_labels_prediction_3(mnist.test.images, 113 mnist.test.labels, 114 prediction_result,0)

    运行结果:

     

    萍水相逢逢萍水,浮萍之水水浮萍!
  • 相关阅读:
    [面试题]去除字符串中相邻两个字符的重复
    [面试题]单向链表的倒序索引值?
    Android数据存储——文件读写操作(File)
    python操作Excel读写(使用xlrd和xlrt)
    在Ubuntu上安装qq2012客户端
    sharepoint 2010开发webpart(转)

    【Sharepoint 2007】WebPart开发、部署过程全记录(转)
    sharepoint2010最初的了解
    基于windows验证的moss2010站点登录域后还弹出对话框解决方法(转)
  • 原文地址:https://www.cnblogs.com/AIBigTruth/p/9852442.html
Copyright © 2011-2022 走看看