节点(Node)表示数学操作,多维数据数组,也就是张量(tensor),由线(edges)联系,表示节点之间的输入输出关系
计算图computational graph是TF中很重要的一个概念,其是由一系列节点(nodes)组成的图模型,每个节点对应的是TF的一个算子(operation)。每个算子会有输入与输出,并且输入和输出都是张量。所以我们使用TF的算子可以构建自己的深度学习模型,其背后的就是一个计算图。还有一点这个计算图是静态的,意思是这个计算图每个节点接收什么样的张量和输出什么样的张量已经固定下来。要运行这个计算图,你需要开启一个会话(session),在session中这个计算图才可以真正运行。
会话(Session):为了获得图的计算结果,图必须在会话中被启动。图是会话类型的一个成员,会话类型还包括一个runner,负责执行这张图。会话的主要任务是在图运算时分配CPU或GPU。
Runner:在建立图之后,必须使用会话中的Runner来运行图,才能得到结果。在运行图时,需要为所有的变量和占位符赋值,否则就会报错。
TensorflowSharp中的几类主要变量
- Const:常量,这很好理解。它们在定义时就必须被赋值,而且值永远无法被改变。
- Placeholder:占位符。这是一个在定义时不需要赋值,但在使用之前必须赋值(feed)的变量,通常用作训练数据。
- Variable:变量,它和占位符的不同是它在定义时需要赋值,而且它的数值是可以在图的计算过程中随时改变的。因此,占位符通常用作图的输入(即训练数据),而变量用作图中可以被“训练”或“学习”的那些tensor,例如y=ax+b中的a和b。
"""A very simple MNIST classifier.
See extensive documentation at
http://tensorflow.org/tutorials/mnist/beginners/index.md
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Import data
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string('data_dir', '/tmp/data/', 'Directory for storing data') # 把数据放在/tmp/data文件夹中
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) # 读取数据集
# 建立抽象模型
x = tf.placeholder(tf.float32, [None, 784]) # 占位符
y = tf.placeholder(tf.float32, [None, 10])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
a = tf.nn.softmax(tf.matmul(x, W) + b)
# 定义损失函数和训练方法
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(a), reduction_indices=[1])) # 损失函数为交叉熵
optimizer = tf.train.GradientDescentOptimizer(0.5) # 梯度下降法,学习速率为0.5
train = optimizer.minimize(cross_entropy) # 训练目标:最小化损失函数
# Test trained model
correct_prediction = tf.equal(tf.argmax(a, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Train
sess = tf.InteractiveSession() # 建立交互式会话
tf.initialize_all_variables().run()
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
train.run({x: batch_xs, y: batch_ys})
print(sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}))