Recognizing Handwritten Digits with a Convolutional Neural Network in TensorFlow (2): Training

import numpy as np
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tensorflow.examples.tutorials.mnist import input_data


# Target training accuracy
accuracyGoal = 0.98

# Whether the target accuracy has been reached
bFlagGoal = False

# Display a digit image; nBytes holds the 784 grayscale values (floats) of one 28x28 image
def showMnistImg(nBytes):
    imgBytes = nBytes.reshape((28, 28))
    #print(imgBytes)
    plt.figure(figsize=(2.8, 2.8))
    #plt.grid()  # show grid lines
    plt.imshow(imgBytes, cmap=cm.gray)
    plt.show()

# Load the MNIST data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

### The 784 grayscale values of a single handwritten digit, floats in [0, 1)
##print('type(mnist.train.images[0]): ', type(mnist.train.images[0]))  # <class 'numpy.ndarray'>
##print('mnist.train.images.shape: ', mnist.train.images.shape)
##print(mnist.train.images[0])
##
##
### The label of a single handwritten digit
### A one-hot vector has a 1 in exactly one position and 0 in every other position
### The digit n is represented as a 10-dimensional vector whose n-th component (counting from 0) is 1.
##print('type(mnist.train.labels[0]): ', type(mnist.train.labels[0]))  # <class 'numpy.ndarray'>
##print('type(mnist.train.labels.shape): ', type(mnist.train.labels.shape))
##print(mnist.train.labels[0])
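### For example, the digit 3 is encoded as [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].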


# The CNN-related code starts here

# 2-D convolution with stride 1 and SAME (zero) padding
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

# 2x2 max pooling with stride 2, halving each spatial dimension
def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')


# Weight variable initialized from a truncated normal distribution
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

# Bias variable initialized to a small positive constant
def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

# Placeholders: x holds flattened 28x28 images, y_ holds the one-hot labels
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])


# First convolutional layer: 5x5 kernels, 1 input channel, 32 output channels
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])

# Reshape the flat 784-value vectors into NHWC image tensors
x_image = tf.reshape(x, [-1, 28, 28, 1])

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
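# after conv1 (SAME padding) and 2x2 pooling: 14x14 feature maps with 32 channels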


# Second convolutional layer: 5x5 kernels, 32 input channels, 64 output channels
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
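# after conv2 and a second 2x2 pooling: 7x7 feature maps with 64 channels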


# Fully connected layer: flatten the 7*7*64 feature maps into 1024 units
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)


# Dropout on the fully connected layer
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
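# keep_prob is fed as 0.5 during training and 1.0 during evaluation (dropout disabled)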


# Output layer: 10 logits, one per digit class
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])

y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2


cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y_conv))
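# Note: softmax_cross_entropy_with_logits_v2 applies the softmax itself, so y_conv stays as raw logits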
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


print('\nStarting training...')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(3000):
        batch = mnist.train.next_batch(50)

        if i % 100 == 0:
            train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
            print('step %d, training accuracy %g' % (i, train_accuracy))

            if train_accuracy > accuracyGoal:
                # Create a Saver object; it adds ops for saving and restoring the model parameters
                saver = tf.train.Saver()
                # Call the save op through the Saver's convenience method
                saver.save(sess, "saved_model/cnn_handwrite_number.ckpt")

                print('Model saved')
                bFlagGoal = True
                break

        train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

if bFlagGoal:
    print('Training finished: the accuracy goal was reached')
else:
    print('Training finished: the accuracy goal was not reached')

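The post stops once the checkpoint has been saved. As a rough sketch (not part of the original article, and assuming the graph-construction code above has already been executed in the same process), the saved checkpoint could later be restored and evaluated on the MNIST test set like this:

# Sketch only: restore the checkpoint and run inference (assumes the graph above is already built)
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "saved_model/cnn_handwrite_number.ckpt")

    # Evaluate on the test set; keep_prob=1.0 disables dropout for evaluation
    test_accuracy = accuracy.eval(feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
    print('test accuracy %g' % test_accuracy)

    # Predict the digit for a single test image
    predicted = sess.run(tf.argmax(y_conv, 1), feed_dict={
        x: mnist.test.images[:1], keep_prob: 1.0})
    print('predicted digit:', predicted[0])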
This article is an original work by hATEmATH. Please credit the source when reposting: http://www.cnblogs.com/hatemath/
Original post: https://www.cnblogs.com/hatemath/p/8513795.html