zoukankan      html  css  js  c++  java
  • TensorFlow-mnist

    训练代码:

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import os

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    
    # Train a two-conv-layer CNN on MNIST and save the checkpoint that the test
    # script restores. Tensor names 'input', 'label', 'keep_prob', 'output' and
    # 'accuracy' are looked up by name after import_meta_graph, so keep them stable.
    flags = tf.app.flags
    FLAGS = flags.FLAGS
    flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data')
    # Checkpoint path prefix; the default is the location the test script reads from.
    flags.DEFINE_string('model_path', 'E:/dnn/model',
                        'Checkpoint path prefix for the trained model')

    print(FLAGS.data_dir)
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Graph inputs: flattened 784-pixel rows, one-hot labels over 10 classes,
    # and the dropout keep probability.
    x = tf.placeholder(tf.float32, [None, 784], name='input')
    label = tf.placeholder(tf.float32, [None, 10], name='label')
    keep_prob = tf.placeholder(tf.float32, name='keep_prob')

    # Rows -> NHWC images of shape 28x28x1.
    image = tf.reshape(x, [-1, 28, 28, 1])

    # Conv block 1: 5x5 conv with 32 filters, ELU, 2x2 max-pool (28x28 -> 14x14).
    conv1_W = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1))
    conv1_b = tf.Variable(tf.constant(0.1, shape=[32]))
    layer1 = tf.nn.elu(
        tf.nn.conv2d(image, conv1_W, strides=[1, 1, 1, 1], padding='SAME') + conv1_b)
    layer2 = tf.nn.max_pool(layer1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    # Conv block 2: 5x5 conv with 64 filters, ELU, 2x2 max-pool (14x14 -> 7x7).
    conv2_W = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1))
    conv2_b = tf.Variable(tf.constant(0.1, shape=[64]))
    layer3 = tf.nn.elu(
        tf.nn.conv2d(layer2, conv2_W, strides=[1, 1, 1, 1], padding='SAME') + conv2_b)
    layer4 = tf.nn.max_pool(layer3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    # Fully connected layer (1024 units, ELU) on the flattened 7*7*64 feature map.
    fc1_W = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1))
    fc1_b = tf.Variable(tf.constant(0.1, shape=[1024]))
    layer5 = tf.reshape(layer4, [-1, 7 * 7 * 64])
    layer6 = tf.nn.elu(tf.matmul(layer5, fc1_W) + fc1_b)

    layer7 = tf.nn.dropout(layer6, keep_prob)

    # Output layer: 10 logits, exposed as class probabilities under the name 'output'.
    fc2_W = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
    fc2_b = tf.Variable(tf.constant(0.1, shape=[10]))
    logits = tf.matmul(layer7, fc2_W) + fc2_b
    output = tf.nn.softmax(logits, name='output')

    # Compute the loss from the logits rather than -sum(label * log(softmax)):
    # once a softmax probability underflows to 0, log() gives -inf and the loss
    # and gradients become NaN.
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logits))

    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')

    # The with-block makes sess the default session (so .run()/.eval() work)
    # and closes it when training and saving are done.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(20000):
            batch = mnist.train.next_batch(50)
            train_step.run(feed_dict={x: batch[0], label: batch[1], keep_prob: 0.5})
            if i % 100 == 0:
                # Accuracy on the current mini-batch, with dropout disabled.
                train_accuracy = accuracy.eval(
                    feed_dict={x: batch[0], label: batch[1], keep_prob: 1.0})
                print("%d:training accuracy %g" % (i, train_accuracy))

        # Saver.save refuses to write into a missing directory, so create it first.
        model_dir = os.path.dirname(FLAGS.model_path)
        if model_dir:
            tf.gfile.MakeDirs(model_dir)
        saver = tf.train.Saver()
        save_path = saver.save(sess, FLAGS.model_path)
        print('Model saved to %s' % save_path)

    测试代码:

    from __future__ import division
    import numpy as np
    import tensorflow as tf
    from PIL import Image
    
    # Classify a single digit image with the checkpoint written by the training script.
    MODEL_DIR = 'E:/dnn'
    IMAGE_PATH = 'E:/dnn/test.bmp'

    # Load as 8-bit grayscale ('L') and resize to 28x28 if needed.
    img = Image.open(IMAGE_PATH).convert('L')
    if img.size != (28, 28):
        img = img.resize((28, 28))
    # Scale to [0, 1] and invert (1 - p/255), so dark strokes on a light background
    # become high values -- presumably to match the MNIST training digits; confirm
    # for your own input images. np.asarray yields rows of pixels, which is the
    # row-major order the 784-wide 'input' placeholder expects.
    pixels = np.asarray(img, dtype=np.float64)
    image = (1.0 - pixels / 255.0).reshape((1, 28, 28, 1))

    # Import the graph from the same checkpoint whose weights are restored, so the
    # graph structure and the variable values cannot come from different saves.
    checkpoint = tf.train.latest_checkpoint(MODEL_DIR)
    if checkpoint is None:
        raise IOError('No checkpoint found in %s' % MODEL_DIR)
    saver = tf.train.import_meta_graph(checkpoint + '.meta')
    graph = tf.get_default_graph()
    # Tensors are found by the names assigned when the training graph was built.
    x = graph.get_tensor_by_name('input:0')
    output = graph.get_tensor_by_name('output:0')
    keep_prob = graph.get_tensor_by_name('keep_prob:0')

    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        # 'output' depends only on 'input' and 'keep_prob', so the label placeholder
        # need not be fed; keep_prob=1.0 disables dropout for inference.
        test = sess.run(output, feed_dict={x: image.reshape(-1, 784), keep_prob: 1.0})
        print(test)
        # Predicted digit: index of the highest probability
        # (np.argmax returns the first such index on ties).
        ans = int(np.argmax(test[0]))
        print(ans)

    测试结果:（此处内容缺失）

  • 相关阅读:
    C语言单链表创建,插入,删除
    Java乔晓松spring构造函数的注入以及null的注入
    sentilib_语料库项目_search模块的实现
    spring入门(6)set方法注入依赖之null的注入
    Java乔晓松使用Filter过滤器清除网页缓存
    漂亮的弹框
    C#判断各种字符串(如手机号)
    视频数字水印
    数据校验
    SVN 常见问题操作总结
  • 原文地址:https://www.cnblogs.com/dramstadt/p/7453807.html
Copyright © 2011-2022 走看看