  • fully_connected_feed source code analysis

    Background:

    argparse: used to read arguments from the command line when running a Python script.

    For example, prog.py:

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("square", help="display a square of a given number",
                        type=int)
    args = parser.parse_args()
    print(args.square ** 2)

    On the command line:

    $ python prog.py 4
    16
    $ python prog.py four
    usage: prog.py [-h] square
    prog.py: error: argument square: invalid int value: 'four'

    Explanation:

    parser.add_argument("square", help="display a square of a given number",
                        type=int)

    # 插入参数square,并且要求该参数为int类型
    # $ python prog.py 4

    # 16
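
    The TensorFlow script analyzed below uses optional --flag style arguments
    together with parse_known_args() instead of a positional argument. Here is a
    minimal sketch of that pattern (the flag name --steps is made up for
    illustration):

    import argparse

    parser = argparse.ArgumentParser()
    # An optional flag with a type and a default, like the script's flags below.
    parser.add_argument('--steps', type=int, default=100,
                        help='Number of steps to run.')
    # parse_known_args() returns (namespace, leftover_argv) and does not raise
    # an error on arguments it does not recognize.
    args, unparsed = parser.parse_known_args()
    print(args.steps, unparsed)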

    os and sys: the os module handles the program's interaction with the operating system and provides access to OS-level interfaces; the sys module handles the program's interaction with the Python interpreter and provides functions and variables for working with the Python runtime. Common methods of each are listed below, each list followed by a short usage sketch.

    <os 常用方法>
    os.remove(‘path/filename’) 删除文件
    os.rename(oldname, newname) 重命名文件
    os.walk() 生成目录树下的所有文件名
    os.chdir('dirname') 改变目录
    os.mkdir/makedirs('dirname')创建目录/多层目录
    os.rmdir/removedirs('dirname') 删除目录/多层目录
    os.listdir('dirname') 列出指定目录的文件
    os.getcwd() 取得当前工作目录
    os.chmod() 改变目录权限
    os.path.basename(‘path/filename’) 去掉目录路径,返回文件名
    os.path.dirname(‘path/filename’) 去掉文件名,返回目录路径
    os.path.join(path1[,path2[,...]]) 将分离的各部分组合成一个路径名
    os.path.split('path') 返回( dirname(), basename())元组
    os.path.splitext() 返回 (filename, extension) 元组
    os.path.getatimectimemtime 分别返回最近访问、创建、修改时间
    os.path.getsize() 返回文件大小
    os.path.exists() 是否存在
    os.path.isabs() 是否为绝对路径
    os.path.isdir() 是否为目录
    os.path.isfile() 是否为文件
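
    A minimal sketch tying a few of these together (the directory name demo_dir
    is made up for illustration):

    import os

    # Create a directory only if it does not already exist.
    if not os.path.exists('demo_dir'):
        os.mkdir('demo_dir')
    print(os.getcwd())                        # current working directory
    print(os.path.join('demo_dir', 'a.txt'))  # portable path joining
    print(os.path.splitext('a.txt'))          # ('a', '.txt')
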
    <sys 常用方法>
    sys.argv 命令行参数List,第一个元素是程序本身路径
    sys.modules.keys() 返回所有已经导入的模块列表
    sys.exc_info() 获取当前正在处理的异常类,exc_type、exc_value、exc_traceback当前处理的异常详细信息
    sys.exit(n) 退出程序,正常退出时exit(0)
    sys.hexversion 获取Python解释程序的版本值,16进制格式如:0x020403F0
    sys.version 获取Python解释程序的版本信息
    sys.maxint 最大的Int值
    sys.maxunicode 最大的Unicode值
    sys.modules 返回系统导入的模块字段,key是模块名,value是模块
    sys.path 返回模块的搜索路径,初始化时使用PYTHONPATH环境变量的值
    sys.platform 返回操作系统平台名称
    sys.stdout 标准输出
    sys.stdin 标准输入
    sys.stderr 错误输出
    sys.exc_clear() 用来清除当前线程所出现的当前的或最近的错误信息
    sys.exec_prefix 返回平台独立的python文件安装的位置
    sys.byteorder 本地字节规则的指示器,big-endian平台的值是'big',little-endian平台的值是'little'
    sys.copyright 记录python版权相关的东西
    sys.api_version 解释器的C的API版本
    sys.stdin,sys.stdout,sys.stderr
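
    And a similarly minimal sketch for sys:

    import sys

    print(sys.argv[0])   # path of the running script itself
    print(sys.platform)  # e.g. 'linux2' or 'win32'
    print(sys.version)   # interpreter version string
    # Exit normally when no extra arguments were given.
    if len(sys.argv) < 2:
        sys.exit(0)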

    fully_connected_feed source code analysis

    # pylint: disable=missing-docstring
    import argparse
    import os
    import sys
    import time

    from six.moves import xrange
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    from tensorflow.examples.tutorials.mnist import mnist
    
    
    # Basic model parameters as external flags.
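    # FLAGS is populated in the __main__ block by parser.parse_known_args().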
    FLAGS = None
    
    
    def placeholder_inputs(batch_size):
      """
        Define the two input placeholders, with their shapes set explicitly.
      Args:
        batch_size: first dimension of both images_placeholder and labels_placeholder.
    
      Returns:
        images_placeholder: Images placeholder.
        labels_placeholder: Labels placeholder.
      """
      # Note that the shapes of the placeholders match the shapes of the full
      # image and label tensors, except the first dimension is now batch_size
      # rather than the full size of the train or test data sets.
      images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                             mnist.IMAGE_PIXELS))
      labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size,))
      return images_placeholder, labels_placeholder
    
    def fill_feed_dict(data_set, images_pl, labels_pl):
      """
      将data_set放入images_pl和labels_pl的,然后再放入feed_dict中用于动态训练
    
      Args:
        data_set: 来自于input_data.read_data_sets()的数据集
        images_pl: 来自于laceholder_inputs()的images placeholder
        labels_pl: 来自于laceholder_inputs()的labels placeholder
    
      Returns:
        feed_dict: 向feed_dict中传入值images_feed和labels_feed
      """
      # Create the feed_dict for the placeholders filled with the next
      # `batch size` examples.
      images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                     FLAGS.fake_data)
      feed_dict = {
          images_pl: images_feed,
          labels_pl: labels_feed,
      }
      return feed_dict
    
    
    def do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_set):
      """Runs one evaluation against the full epoch of data.
    
      Args:
        sess: The session in which the model has been trained.
        eval_correct: The Tensor that returns the number of correct predictions.
        images_placeholder: The images placeholder.
        labels_placeholder: The labels placeholder.
        data_set: The set of images and labels to evaluate, from
          input_data.read_data_sets().
      """
      # And run one epoch of eval.
      true_count = 0  # Counts the number of correct predictions.
      steps_per_epoch = data_set.num_examples // FLAGS.batch_size
      num_examples = steps_per_epoch * FLAGS.batch_size
      for step in xrange(steps_per_epoch):
        feed_dict = fill_feed_dict(data_set,
                                   images_placeholder,
                                   labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=feed_dict)
      precision = float(true_count) / num_examples
      print('Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
            (num_examples, true_count, precision))
    
    
    def run_training():
      """Train MNIST for a number of steps."""
      # Get the sets of images and labels for training, validation, and
      # test on MNIST.
      data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)
    
      # Tell TensorFlow that the model will be built into the default Graph.
      with tf.Graph().as_default():
        # Generate placeholders for the images and labels.
        images_placeholder, labels_placeholder = placeholder_inputs(
            FLAGS.batch_size)
    
        # Build a Graph that computes predictions from the inference model.
        logits = mnist.inference(images_placeholder,
                                 FLAGS.hidden1,
                                 FLAGS.hidden2)
    
        # Add to the Graph the Ops for loss calculation.
        loss = mnist.loss(logits, labels_placeholder)
    
        # Add to the Graph the Ops that calculate and apply gradients.
        train_op = mnist.training(loss, FLAGS.learning_rate)
    
        # Add the Op to compare the logits to the labels during evaluation.
        eval_correct = mnist.evaluation(logits, labels_placeholder)
    
        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()
    
        # Add the variable initializer Op.
        init = tf.global_variables_initializer()
    
        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver()
    
        # Create a session for running Ops on the Graph.
        sess = tf.Session()
    
        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
    
        # And then after everything is built:
    
        # Run the Op to initialize the variables.
        sess.run(init)
    
        # Start the training loop.
        for step in xrange(FLAGS.max_steps):
          start_time = time.time()
    
          # Fill a feed dictionary with the actual set of images and labels
          # for this particular training step.
          feed_dict = fill_feed_dict(data_sets.train,
                                     images_placeholder,
                                     labels_placeholder)
    
          # Run one step of the model.  The return values are the activations
          # from the `train_op` (which is discarded) and the `loss` Op.  To
          # inspect the values of your Ops or variables, you may include them
          # in the list passed to sess.run() and the value tensors will be
          # returned in the tuple from the call.
          _, loss_value = sess.run([train_op, loss],
                                   feed_dict=feed_dict)
    
          duration = time.time() - start_time
    
          # Write the summaries and print an overview fairly often.
          if step % 100 == 0:
            # Print status to stdout.
            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
            # Update the events file.
            summary_str = sess.run(summary, feed_dict=feed_dict)
            summary_writer.add_summary(summary_str, step)
            summary_writer.flush()
    
          # Save a checkpoint and evaluate the model periodically.
          if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
            checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
            saver.save(sess, checkpoint_file, global_step=step)
            # Evaluate against the training set.
            print('Training Data Eval:')
            do_eval(sess,
                    eval_correct,
                    images_placeholder,
                    labels_placeholder,
                    data_sets.train)
            # Evaluate against the validation set.
            print('Validation Data Eval:')
            do_eval(sess,
                    eval_correct,
                    images_placeholder,
                    labels_placeholder,
                    data_sets.validation)
            # Evaluate against the test set.
            print('Test Data Eval:')
            do_eval(sess,
                    eval_correct,
                    images_placeholder,
                    labels_placeholder,
                    data_sets.test)
    
    
    def main(_):
      if tf.gfile.Exists(FLAGS.log_dir):
        tf.gfile.DeleteRecursively(FLAGS.log_dir)
      tf.gfile.MakeDirs(FLAGS.log_dir)
      run_training()
    
    
    if __name__ == '__main__':
      parser = argparse.ArgumentParser()
      parser.add_argument(
          '--learning_rate',
          type=float,
          default=0.01,
          help='Initial learning rate.'
      )
      parser.add_argument(
          '--max_steps',
          type=int,
          default=2000,
          help='Number of steps to run trainer.'
      )
      parser.add_argument(
          '--hidden1',
          type=int,
          default=128,
          help='Number of units in hidden layer 1.'
      )
      parser.add_argument(
          '--hidden2',
          type=int,
          default=32,
          help='Number of units in hidden layer 2.'
      )
      parser.add_argument(
          '--batch_size',
          type=int,
          default=100,
          help='Batch size.  Must divide evenly into the dataset sizes.'
      )
      parser.add_argument(
          '--input_data_dir',
          type=str,
          default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                               'tensorflow/mnist/input_data'),
          help='Directory to put the input data.'
      )
      parser.add_argument(
          '--log_dir',
          type=str,
          default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
                               'tensorflow/mnist/logs/fully_connected_feed'),
          help='Directory to put the log data.'
      )
      parser.add_argument(
          '--fake_data',
          default=False,
          help='If true, uses fake data for unit testing.',
          action='store_true'
      )
    
      FLAGS, unparsed = parser.parse_known_args()
      tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
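
    To run the script, assuming it is saved as fully_connected_feed.py (the file
    name and flag values below are just examples), an invocation might look like:

    $ python fully_connected_feed.py --max_steps 2000 --batch_size 100
    $ tensorboard --logdir /tmp/tensorflow/mnist/logs/fully_connected_feed

    With TEST_TMPDIR unset, the summaries and checkpoints land in the default
    log_dir shown above. --fake_data is a store_true flag, so passing it on the
    command line sets it to True; the other flags take explicit values.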