zoukankan      html  css  js  c++  java
  • TensorFlow ——神经网络

    Training Data Eval:
    Num examples: 55000 Num correct: 52015 Precision @ 1: 0.9457
    Validation Data Eval:
    Num examples: 5000 Num correct: 4740 Precision @ 1: 0.9480
    Test Data Eval:
    Num examples: 10000 Num correct: 9456 Precision @ 1: 0.9456

      1 import tensorflow as tf
      2 import input_data
      3 import math
      4 
      5 NUM_CLASSES = 10
      6 IMAGE_SIZE = 28
      7 IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE
      8 flags = tf.app.flags
      9 FLAGS = flags.FLAGS
     10 flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
     11 flags.DEFINE_integer('max_steps', 10000, 'Number of steps to run trainer.')
     12 flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
     13 flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
     14 flags.DEFINE_integer('batch_size', 100, 'Batch size.  '
     15                      'Must divide evenly into the dataset sizes.')
     16 flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
     17 flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
     18                      'for unit testing.')
     19 
     20 def inference(images, hidden1_units, hidden2_units):
     21   with tf.name_scope('hidden1'):
     22     weights = tf.Variable(
     23         tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
     24                             stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),
     25         name='weights')
     26     biases = tf.Variable(tf.zeros([hidden1_units]),
     27                          name='biases')
     28     hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
     29   with tf.name_scope('hidden2'):
     30     weights = tf.Variable(
     31         tf.truncated_normal([hidden1_units, hidden2_units],
     32                             stddev=1.0 / math.sqrt(float(hidden1_units))),
     33         name='weights')
     34     biases = tf.Variable(tf.zeros([hidden2_units]),
     35                          name='biases')
     36     hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
     37   with tf.name_scope('softmax_linear'):
     38     weights = tf.Variable(
     39         tf.truncated_normal([hidden2_units, NUM_CLASSES],
     40                             stddev=1.0 / math.sqrt(float(hidden2_units))),
     41         name='weights')
     42     biases = tf.Variable(tf.zeros([NUM_CLASSES]),
     43                          name='biases')
     44     logits = tf.matmul(hidden2, weights) + biases
     45   return logits
     46 
     47 def loss(logits, labels):
     48   labels = tf.to_int64(labels)
     49   cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
     50       logits, labels, name='xentropy')
     51   loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')
     52   return loss
     53 
     54 def training(loss, learning_rate):
     55   tf.scalar_summary(loss.op.name, loss)
     56   optimizer = tf.train.GradientDescentOptimizer(learning_rate)
     57   global_step = tf.Variable(0, name='global_step', trainable=False)
     58   train_op = optimizer.minimize(loss, global_step=global_step)
     59   return train_op
     60 
     61 def evaluation(logits, labels):
     62   correct = tf.nn.in_top_k(logits, labels, 1)
     63   return tf.reduce_sum(tf.cast(correct, tf.int32))
     64 
     65 def placeholder_inputs(batch_size):
     66   images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
     67                                                          IMAGE_PIXELS))
     68   labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
     69   return images_placeholder, labels_placeholder
     70 
     71 
     72 def fill_feed_dict(data_set, images_pl, labels_pl):
     73   images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
     74                                                  FLAGS.fake_data)
     75   feed_dict = {
     76       images_pl: images_feed,
     77       labels_pl: labels_feed,
     78   }
     79   return feed_dict
     80 
     81 
     82 def do_eval(sess,
     83             eval_correct,
     84             images_placeholder,
     85             labels_placeholder,
     86             data_set):
     87   true_count = 0
     88   steps_per_epoch = data_set.num_examples // FLAGS.batch_size
     89   num_examples = steps_per_epoch * FLAGS.batch_size
     90   for step in range(steps_per_epoch):
     91     feed_dict = fill_feed_dict(data_set,
     92                                images_placeholder,
     93                                labels_placeholder)
     94     true_count += sess.run(eval_correct, feed_dict=feed_dict)
     95   precision = true_count / num_examples
     96   print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
     97         (num_examples, true_count, precision))
     98 
     99 def run_training():
    100   data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
    101   print(FLAGS.train_dir, FLAGS.fake_data)
    102   with tf.Graph().as_default():
    103     images_placeholder, labels_placeholder = placeholder_inputs(
    104         FLAGS.batch_size)
    105     logits = inference(images_placeholder,
    106                              FLAGS.hidden1,
    107                              FLAGS.hidden2)
    108     loss_minist = loss(logits, labels_placeholder)
    109     train_op = training(loss_minist, FLAGS.learning_rate)
    110     eval_correct = evaluation(logits, labels_placeholder)
    111     summary = tf.merge_all_summaries()
    112     init = tf.initialize_all_variables()
    113     sess = tf.Session()
    114     summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)
    115     sess.run(init)
    116     for step in range(FLAGS.max_steps):
    117       feed_dict = fill_feed_dict(data_sets.train,
    118                                  images_placeholder,
    119                                  labels_placeholder)
    120       _, loss_value = sess.run([train_op, loss_minist],
    121                                feed_dict=feed_dict)
    122 
    123       if step % 100 == 0:
    124         print('Step %d: loss = %.2f' % (step, loss_value))
    125         summary_str = sess.run(summary, feed_dict=feed_dict)
    126         summary_writer.add_summary(summary_str, step)
    127         summary_writer.flush()
    128       if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
    129         print('Training Data Eval:')
    130         do_eval(sess,
    131                 eval_correct,
    132                 images_placeholder,
    133                 labels_placeholder,
    134                 data_sets.train)
    135         print('Validation Data Eval:')
    136         do_eval(sess,
    137                 eval_correct,
    138                 images_placeholder,
    139                 labels_placeholder,
    140                 data_sets.validation)
    141         print('Test Data Eval:')
    142         do_eval(sess,
    143                 eval_correct,
    144                 images_placeholder,
    145                 labels_placeholder,
    146                 data_sets.test)
    147 run_training()
  • 相关阅读:
    如何制作URL文件
    对象映射工具AutoMapper介绍
    C#高阶函数介绍
    System.Web.Caching.Cache
    系统架构设计:进程缓存和缓存服务,如何抉择?
    System.Web.Caching.Cache类 缓存 各种缓存依赖
    max server memory (MB)最大服务器内存配置--缓解内存压力
    第0节:.Net版基于WebSocket的聊天室样例
    第六节:Core SignalR中的重连机制和心跳监测机制详解
    第五节:SignalR完结篇之依赖注入和分布式部署
  • 原文地址:https://www.cnblogs.com/qw12/p/6139446.html
Copyright © 2011-2022 走看看