zoukankan      html  css  js  c++  java
  • TF启程

    我第一次开始接触到TensorFlow大概是去年五月份,大三下,如果一年多已过,我却还在写启程。。这进度,实在汗颜。。

    一个完整的tensorflow程序可以分为以下几部分:

    • Inputs and Placeholders
    • Build the Graph
      • Inference
      • Loss
    • Training
    • Train the Model
    • Visualize the Status
    • Save a Checkpoint
    • Evaluate the Model
    • Build the Eval Graph
    • Eval Output

    Inputs and Placeholders

    对于一个完整的网络来说,必定有输入还有输出,而Placeholders就是针对网络输入来的,相当于预先给输入变量占个坑,拿mnist来说,占坑代码可以如下面的例子:

    1 images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,mnist.IMAGE_PIXELS))
    3 labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))

    上述代码相当于为mnist图片和标签分别占坑,而tf.placeholder参数可以如下面所示:

    tf.placeholder(dtype, shape=None, name=None)

    即需要提供占坑数据类型dtype,占坑数据shape,当然也可以给它提供一个唯一的name

    Build the Graph

    因为tf是通过构建图模型来进行网络搭建的,因此搭建网络也就是’Build the Graph’。

    Inference

    首先就是构建图,利用一系列符号将要表达的操作表达清楚,以用于后续模型的训练。如下面代码:

    1 with tf.name_scope('hidden1'):
    2     weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units],
    3         stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),name='weights')
    4 
    5     biases = tf.Variable(tf.zeros([hidden1_units]),
    6     name='biases')

    如上述代码,对于一个图的搭建,需要一些变量来支持我们的运算,比如矩阵相乘等,需要通过tf.Variable来声明变量,其参数格式如下:

    1 tf.Variable(self, initial_value=None, trainable=True, collections=None, validate_shape=True,
    2     caching_device=None, name=None, variable_def=None, dtype=None)

    需要提供变量初始值initial_value, 是否接受训练trainable,对于validate_shape表示该变量是否可以改变,如果形状可以改变,那么应该为False。对于每个变量,可以赋予不同的名字tf.name_scope

    Loss

    在定义完图结构之后,我们需要有个目标函数,用作更新图结构中的各个变量。

    1 labels = tf.to_int64(labels)
    3 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='xentropy')

    如上,通过给定的labels占坑变量,完成手写数字识别的最后交叉熵函数。

    Training

    在得到目标函数之后,我们就可以对模型进行训练,这里常用梯度下降法。在训练阶段,我们可以通过tf.scalar_summary来实现变量的记录,用作后续的tensorboard的可视化,如:

    1 tf.scalar_summary(loss.op.name, loss)

    然后通过tf.SummaryWriter()来得到对应的提交值。

    而对于模型的最优化,这里 tf 提供了很多optimazer,通常在tf.train里面,这里常用的是GradientDenscentOptimizer(lr),然后通过调用:

    1 train_op = optimizer.minimize(loss, global_step=global_step)

    Train the Model

    在模型训练时,我们需要打开一个默认的图环境,用作训练,如:

    1 with tf.Graph().as_default():

    以此来打开一个图结构,然后我们需要声明一个会话在所有操作都定义完毕之后,这样我们就可以利用这个session来运行Graph.可以通过如下方法声明:

    1 with tf.Session() as sess:
    2     init = tf.initialize_all_variables()
    3     sess.run(init)

    每次我们可以通过sess.run来运行一些操作,进而获取其输出值,

    1 sess.run(fetches, feed_dict=None, options=None, run_metadata=None)

    可以看到,run需要fetches,即操作,feed_dictfetches的输入,即占坑变量与其对应值构成的字典。

    Visualize the Status

    当然,在运行过程,我们可以通过可视化的操作来看网络运行情况。
    在之前的tf.scalar_summary, 我们可以通过:

    1 summary = tf.merge_all_summaries()

    将在图构建阶段的变量收集起来,然后在session创建之后运行如下命令生成可视化的值。

    1 summary_str = sess.run(summary, feed_dict=feed_dict)
    2 summary_writer.add_summary(summary_str, step)

    其中summary_writer由如下得到:

    1 summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    然后用tensorboard打开对应文件即可。

    Save a Chenckpoint

    对于模型的保存,可以通过如下代码实现:

    1 saver = tf.train.Saver()  
    2 saver.save(sess, FLAGS.train_dir, global_step=step)

    而载入模型可以通过如下的代码来实现:

    saver.restore(sess, FLAGS.train_dir)

    当然了,模型的估计就类似上述了。

    这样简单的模型搭建到运行就完成了。本文主要用到这些函数:

      • tf.placeholder
      • tf.Variable
      • tf.train
        • tf.train.GradientDenscentOptimizer
        • tf.train.SummaryWriter
        • tf.train.Saver
      • tf.session
      • tf.Graph
      • tf.add_summary
      • tf.merge_all_summaries

    其实构建一个模型基本就用这些函数,然后就是一些数理计算方法。详情参看tensorflow

    Reference:Link

  • 相关阅读:
    Spark Netty与Jetty (源码阅读十一)
    Netty服务端与客户端(源码一)
    NIO源码阅读
    Spark之SQL解析(源码阅读十)
    Spark BlockManager的通信及内存占用分析(源码阅读九)
    Spark Job的提交与task本地化分析(源码阅读八)
    Spark Shuffle数据处理过程与部分调优(源码阅读七)
    Spark常用函数(源码阅读六)
    Spark数据传输及ShuffleClient(源码阅读五)
    SparkConf加载与SparkContext创建(源码阅读四)
  • 原文地址:https://www.cnblogs.com/niuxichuan/p/9147638.html
Copyright © 2011-2022 走看看