zoukankan      html  css  js  c++  java
  • tensorflow--mnist注解

    我对 MNIST 官方例程做了部分注解，分享出来，希望能帮助入门者更好地理解 TensorFlow 的运行机制。建议把代码拷贝到 IDE 中调试（复制时需去掉每行开头的行号），观察具体的数据流向，以及其中用到的部分 TensorFlow 库。
    我使用的是通过 pip 安装的 tensorflow-gpu 1.13。这段源码的原始位置在 https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py（master 分支此后可能已有改动）。

    代码:

      1 from __future__ import absolute_import
      2 from __future__ import division
      3 from __future__ import print_function
      4 
      5 # absl is Google's Abseil Python library (pip package absl-py), a third-party package, NOT part of the Python standard library
      6 from absl import app as absl_app
      7 from absl import flags
      8 
      9 import tensorflow as tf  # pylint: disable=g-bad-import-order
     10 
     11 from official.mnist import dataset
     12 from official.utils.flags import core as flags_core
     13 from official.utils.logs import hooks_helper
     14 from official.utils.misc import distribution_utils
     15 from official.utils.misc import model_helpers
     16 
     17 
     18 LEARNING_RATE = 1e-4  # learning rate for tf.train.AdamOptimizer in model_fn (TRAIN mode)
     19 
     20 # data_format is chosen in run_mnist: flags_obj.data_format, or if unset 'channels_first' on CUDA builds and 'channels_last' otherwise
     21 def create_model(data_format):
     22   """Model to recognize digits in the MNIST dataset.
     23 
     24   Network structure is equivalent to:
     25   https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
     26   and
     27   https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
     28 
     29   But uses the tf.keras API.
     30 
     31   Args:
     32     data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is
     33       typically faster on GPUs while 'channels_last' is typically faster on
     34       CPUs. See
     35       https://www.tensorflow.org/performance/performance_guide#data_formats
     36 
     37   Returns:
     38     A tf.keras.Model.
     39   """
     40 
     41   # data_format is either 'channels_last' or 'channels_first' and gives the order of the
     42   # input dimensions: 'channels_last' means (batch, height, width, channels),
     43   # 'channels_first' means (batch, channels, height, width).
     44   # input_shape below omits the batch axis; the channel axis has size 1 (single-channel images).
     45   if data_format == 'channels_first':
     46     input_shape = [1, 28, 28]
     47   else:
     48     assert data_format == 'channels_last'  # NOTE(review): skipped under python -O, where any other value is treated as channels_last
     49     input_shape = [28, 28, 1]
     50 
     51   # l aliases tf.keras.layers; max_pool is ONE 2x2, stride-2 MaxPooling2D layer reused after both conv layers
     52   l = tf.keras.layers
     53   max_pool = l.MaxPooling2D(
     54       (2, 2), (2, 2), padding='same', data_format=data_format)
     55   # The model consists of a sequential chain of layers, so tf.keras.Sequential
     56   # (a subclass of tf.keras.Model) makes for a compact description.
     57   return tf.keras.Sequential(
     58       [
     59           # Reshape each flat 784-element input (28 * 28,) into input_shape, i.e. [1, 28, 28] or [28, 28, 1]
     60           l.Reshape(
     61               target_shape=input_shape,
     62               input_shape=(28 * 28,)),
     63           # Convolution
     64           l.Conv2D(
     65               32,  # filters: number of output channels, i.e. number of convolution kernels
     66               5,  # kernel_size: 5 means a 5x5 kernel
     67               padding='same',
     68               data_format=data_format,
     69               activation=tf.nn.relu),
     70           # Max pooling (2x2, stride 2)
     71           max_pool,
     72           # Convolution
     73           l.Conv2D(
     74               64,  # 64 filters
     75               5,
     76               padding='same',
     77               data_format=data_format,
     78               activation=tf.nn.relu),
     79           # Max pooling (same layer instance as above)
     80           max_pool,
     81           # Flatten every axis except axis 0 (the batch axis) into a single vector
     82           l.Flatten(),
     83           # Fully connected layer with 1024 units
     84           l.Dense(1024, activation=tf.nn.relu),
     85           l.Dropout(0.4),  # rate=0.4 is the fraction of units dropped; only active when called with training=True
     86           # Output layer: 10 raw logits, one per digit class; softmax / cross-entropy are applied in model_fn
     87           l.Dense(10)
     88       ])
     89 
     90 # Register the command-line flags read by run_mnist and override some of their default values (directories, batch size, epochs, ...)
     91 def define_mnist_flags():
     92   flags_core.define_base()  # presumably defines data_dir, model_dir, batch_size, train_epochs, ... (their defaults are overridden below)
     93   flags_core.define_performance(num_parallel_calls=False)
     94   flags_core.define_image()
     95   flags.adopt_module_key_flags(flags_core)  # treat flags_core's key flags as key flags of this module (shown in --help)
     96   # Model-specific default values for the flags registered above
     97   flags_core.set_defaults(data_dir='./tmp/mnist_data',
     98                           model_dir='./tmp/mnist_model',
     99                           batch_size=100,
    100                           train_epochs=40,
    101                           stop_threshold=0.998)  # run_mnist stops early once eval accuracy reaches this (via past_stop_threshold)
    102 
    103 
    104 def model_fn(features, labels, mode, params):
    105   """The model_fn argument for creating an Estimator: builds the EstimatorSpec for PREDICT, TRAIN or EVAL mode."""
    106   # Build the Keras model; params['data_format'] comes from Estimator(params=...) in run_mnist
    107   model = create_model(params['data_format'])
    108   image = features
    109   # features may be a dict (e.g. the {'image': ...} serving input built at export time) or a bare image tensor
    110   if isinstance(image, dict):
    111     image = features['image']
    112 
    113   # PREDICT mode (inference / serving) -- not evaluation
    114   if mode == tf.estimator.ModeKeys.PREDICT:
    115     logits = model(image, training=False)
    116     predictions = {
    117         'classes': tf.argmax(logits, axis=1),
    118         'probabilities': tf.nn.softmax(logits),
    119     }
    120     # Prediction returns here; no labels, loss or optimizer are needed
    121     return tf.estimator.EstimatorSpec(
    122         mode=tf.estimator.ModeKeys.PREDICT,
    123         predictions=predictions,
    124         export_outputs={
    125             'classify': tf.estimator.export.PredictOutput(predictions)
    126         })
    127 
    128   # TRAIN mode
    129   if mode == tf.estimator.ModeKeys.TRAIN:
    130     # Adam optimizer using the module-level LEARNING_RATE
    131     optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
    132 
    133     logits = model(image, training=True)
    134     loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    135     accuracy = tf.metrics.accuracy(
    136       labels=labels, predictions=tf.argmax(logits, axis=1))
    137 
    138     # Name tensors to be logged with LoggingTensorHook.
    139     tf.identity(LEARNING_RATE, 'learning_rate')
    140     tf.identity(loss, 'cross_entropy')
    141     tf.identity(accuracy[1], name='train_accuracy')  # tf.metrics.accuracy returns (value, update_op); [1] is the update op
    142 
    143     # Save accuracy scalar to Tensorboard output.
    144     tf.summary.scalar('train_accuracy', accuracy[1])
    145 
    146     return tf.estimator.EstimatorSpec(
    147         mode=tf.estimator.ModeKeys.TRAIN,
    148         loss=loss,
    149         train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step()))  # minimize() also increments the global step
    150   if mode == tf.estimator.ModeKeys.EVAL:  # EVAL mode: loss + accuracy metric; any other mode falls through and returns None
    151     logits = model(image, training=False)
    152     loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    153     return tf.estimator.EstimatorSpec(
    154         mode=tf.estimator.ModeKeys.EVAL,
    155         loss=loss,
    156         eval_metric_ops={
    157             'accuracy':
    158                 tf.metrics.accuracy(
    159                     labels=labels, predictions=tf.argmax(logits, axis=1)),
    160         })
    161 
    162 
    163 def run_mnist(flags_obj):
    164   """Run MNIST training and eval loop.
    165 
    166   Args:
    167     flags_obj: An object containing parsed flag values.
    168   """
    169 
    170   #apply_clean是官方例程里面提供的用来清理现存model的方法,
    171   # 取决于flags_obj.clean(True则清理flags_obj.model_dir内的文件)
    172   model_helpers.apply_clean(flags_obj)
    173 
    174   #把自定义的实现传给tf.estimator.Estimator
    175   model_function = model_fn
    176 
    177   #tf.ConfigProto()主要的作用是配置tf.Session的运算方式,比如gpu运算或者cpu运算
    178   session_config = tf.ConfigProto(
    179       #设置线程一个操作内部并行运算的线程数,比如矩阵乘法,如果设置为0,则表示以最优的线程数处理
    180       inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
    181       #设置多个操作并行运算的线程数,比如 c = a + b,d = e + f . 可以并行运算
    182       intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
    183       #有时候,不同的设备,它的cpu和gpu是不同的,如果将这个选项设置成True,
    184       # 那么当运行设备不满足要求时,会自动分配GPU或者CPU
    185       allow_soft_placement=True)
    186 
    187   #获取gpu数目,优化算法等,用于优化
    188   distribution_strategy = distribution_utils.get_distribution_strategy(
    189       flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
    190 
    191   #所有输出(检查点,事件文件等)都被写入model_dir或其子目录.如果model_dir未设置,则使用临时目录.
    192   #可以通过RunConfig对象(包含了有关执行环境的信息)传递config参数.它被传递给model_fn,
    193   # 如果model_fn有一个名为“config”的参数(和输入函数以相同的方式).如果该config参数未被传递,
    194   # 则由Estimator进行实例化.不传递配置意味着使用对本地执行有用的默认值.Estimator使配置对模型
    195   # 可用(例如,允许根据可用的工作人员数量进行专业化),并且还使用其一些字段来控制内部,特别是关于检查点
    196   run_config = tf.estimator.RunConfig(
    197       train_distribute=distribution_strategy, session_config=session_config)
    198 
    199   data_format = flags_obj.data_format
    200   #channels_first,即(3,128,128,128)通道数在最前面
    201   #channels_last,即(128,128,128,3)通道数在最后面
    202   if data_format is None:
    203     data_format = ('channels_first'
    204                    if tf.test.is_built_with_cuda() else 'channels_last')#判断安装的TF是否支持GPU
    205 
    206   #estimator类对TensorFlow模型进行训练和计算.
    207   #Estimator对象包装由model_fn指定的模型,其中,给定输入和其他一些参数,返回需要进行训练、计算,或预测的操作.
    208   mnist_classifier = tf.estimator.Estimator(
    209       #这个model_fn是参数名而已
    210       model_fn=model_function,#模型对象
    211       model_dir=flags_obj.model_dir,#模型目录,如果为空会创建一个临时目录
    212       #猜测会去model_dir中寻找数据
    213       config=run_config,#运行的一些参数
    214       params={
    215           'data_format': data_format,#数据类型
    216       })
    217 
    218   #这里定义了两个内部函数,只能被这个语句块的内部调用
    219   # Set up training and evaluation input functions.
    220   def train_input_fn():
    221     """Prepare data for training."""
    222 
    223     # When choosing shuffle buffer sizes, larger sizes result in better
    224     # randomness, while smaller sizes use less memory. MNIST is a small
    225     # enough dataset that we can easily shuffle the full epoch.
    226     ds = dataset.train(flags_obj.data_dir)
    227     ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size)
    228 
    229     # Iterate through the dataset a set number (`epochs_between_evals`) of times
    230     # during each training session.
    231     ds = ds.repeat(flags_obj.epochs_between_evals)
    232     return ds
    233 
    234   def eval_input_fn():
    235     return dataset.test(flags_obj.data_dir).batch(
    236         flags_obj.batch_size).make_one_shot_iterator().get_next()
    237 
    238   # Set up hook that outputs training logs every 100 steps.
    239   train_hooks = hooks_helper.get_train_hooks(
    240       flags_obj.hooks, model_dir=flags_obj.model_dir,
    241       batch_size=flags_obj.batch_size)
    242 
    243   # Train and evaluate model.
    244   for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals):
    245     #训练一次,验证一次
    246     mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks)
    247     eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)
    248     print('
    Evaluation results:
    	%s
    ' % eval_results)
    249 
    250     #如果eval_results['accuracy'] >= flags_obj.stop_threshold 说明模型训练好了
    251     if model_helpers.past_stop_threshold(flags_obj.stop_threshold,
    252                                          eval_results['accuracy']):
    253       break
    254 
    255   # Export the model
    256   if flags_obj.export_dir is not None:
    257     #预分配内存,等待数据进入
    258     image = tf.placeholder(tf.float32, [None, 28, 28])
    259     input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({
    260         'image': image,
    261     })
    262     #输出模型
    263     mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn)
    264 
    265 
    266 def main(_):  # absl_app.run calls main with the leftover argv, which is unused here
    267   run_mnist(flags.FLAGS)  # flags.FLAGS holds the parsed command-line flag values
    268 
    269 
    270 if __name__ == '__main__':
    271   # Log at INFO level
    272   tf.logging.set_verbosity(tf.logging.INFO)
    273   # Register the MNIST flags on flags.FLAGS before they are parsed
    274   define_mnist_flags()
    275   # absl_app.run parses the command-line flags, then calls main
    276   absl_app.run(main)

    可以图形化看到的东西坚决不会用命令行ORZ

  • 相关阅读:
    Python3爬取前程无忧数据分析工作并存储到MySQL
    MySQL操作数据库和表的基本语句(DDL
    MyBatis的增删改查操作
    指定方向或者位置移动
    AI-Tank
    转载人家写的CURSOR
    Ajax学习整理笔记
    全面解析注解
    java调用存储过程mysql
    JAVA如何调用mysql写的存储过程
  • 原文地址:https://www.cnblogs.com/IGNB/p/10616359.html
Copyright © 2011-2022 走看看