zoukankan      html  css  js  c++  java
  • tensorflow的MNIST教程解析

    原址:http://tensorfly.cn/tfdoc/tutorials/mnist_pros.html

    eval()函数:计算字符串中有效的表达式;将字符串转化为相应的对象;解析字符串有效的表达式.

    1.InteractiveSession()和Session()的区别

    tf.InteractiveSession()默认自己就是用户要操作的session,而tf.Session()没有这个默认,因此用eval()启动计算时需要指明session。

    2.tensorflow计算原理

    为了在Python中进行高效的数值计算,我们通常会使用像NumPy一类的库,将一些诸如矩阵乘法的耗时操作在Python环境的外部来计算,这些计算通常会通过其它语言并用更为高效的代码来实现。

    但遗憾的是,每一个操作切换回Python环境时仍需要不小的开销。如果你想在GPU或者分布式环境中计算时,这一开销更加可怖,这一开销主要可能是用来进行数据迁移。

    TensorFlow也是在Python外部完成其主要工作,但是进行了改进以避免这种开销。其并没有采用在Python外部独立运行某个耗时操作的方式,而是先让我们描述一个交互操作图,然后完全将其运行在Python外部。这与Theano或Torch的做法类似。

    因此Python代码的目的是用来构建这个可以在外部运行的计算图,以及安排计算图的哪一部分应该被运行。详情请查看基本用法中的计算图表一节。

    3.占位符

    x = tf.placeholder("float", shape=[None, 784])
    y_ = tf.placeholder("float", shape=[None, 10])

    这里的xy并不是特定的值,相反,他们都只是一个占位符,可以在TensorFlow运行某一计算时根据该占位符输入具体的值

    4.变量

    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))

    我们现在为模型定义权重W和偏置b。可以将它们当作额外的输入量,但是TensorFlow有一个更好的处理方式:变量。一个变量代表着TensorFlow计算图中的一个值,能够在计算过程中使用,甚至进行修改。在机器学习的应用过程中,模型参数一般用Variable来表示。

    变量使用前,需要初始化

    sess.run(tf.initialize_all_variables())

    5.类别预测与损失函数

    y = tf.nn.softmax(tf.matmul(x,W) + b)

    计算回归模型(预测值),向量化后的图片x和权重矩阵W相乘,加上偏置b,然后计算每个分类的softmax概率值。

    cross_entropy = -tf.reduce_sum(y_*tf.log(y))

    我们的损失函数是目标类别和预测类别之间的交叉熵。y_*tf.log(y)计算的是交叉熵,tf.reduce_sum()求均值.

    6.训练模型

    我们已经定义好模型和训练用的损失函数,那么用TensorFlow进行训练就很简单了。因为TensorFlow知道整个计算图,它可以使用自动微分法找到对于各个变量的损失的梯度值。TensorFlow有大量内置的优化算法 这个例子中,我们用最速下降法让交叉熵下降,步长为0.01.

    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

    这一行代码实际上是用来往计算图上添加一个新操作,其中包括计算梯度,计算每个参数的步长变化,并且计算出新的参数值。

    返回的train_step操作对象,在运行时会使用梯度下降来更新参数。因此,整个模型的训练可以通过反复地运行train_step来完成。

    for i in range(1000):
      batch = mnist.train.next_batch(50)
      train_step.run(feed_dict={x: batch[0], y_: batch[1]})

    每一步迭代,我们都会加载50个训练样本,然后执行一次train_step,并通过feed_dictxy_张量占位符用训练训练数据替代。运行的过程中,插入新的数据.

    7.评估模型

    先让我们找出那些预测正确的标签。tf.argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签,我们可以用 tf.equal 来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)。

    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

    这里返回一个布尔数组。为了计算我们分类的准确率,我们将布尔值转换为浮点数来代表对、错,然后取平均值。例如:[True, False, True, True]变为[1,0,1,1],计算出平均值为0.75。整体相似程度.

    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

    最后,我们可以计算出在测试数据上的准确率,大概是91%。

    print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  • 相关阅读:
    mac快捷键,pycharm快捷键
    Django进阶之session
    Python:如何将字符串作为变量名
    Ubuntu中创建用户
    redis在centos上的安装
    centos--网络配置问题,提示connect: Network is unreachable
    Python 3.x--paramiko模块详解
    Python 3.x--paramiko模块安装过程中的错误
    Python 3.x--Socket实现简单的ssh和文件下载功能
    Python 3.x--面向对象编程(二)静态方法、类方法、属性方法
  • 原文地址:https://www.cnblogs.com/smartmsl/p/10903227.html
Copyright © 2011-2022 走看看