zoukankan      html  css  js  c++  java
  • Tensorflow学习笔记(二)——MNIST机器学习入门

    一、mnist手写识别代码运行过程

    1 下载数据库
    http://yann.lecun.com/exdb/mnist/下载上面提到的4个gz文件,放到本地目录如 /tmp/mnist

    2 下载input_data.py,放在/home/tmp/test目录下
    https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/g3doc/tutorials/mnist/input_data.py

    3 在/home/tmp/test目录下创建文件test_tensorflow_mnist.py,内容如下:

    !/usr/bin/env python

    import input_data
    import tensorflow as tf

    mnist = input_data.read_data_sets(“/tmp/mnist”, one_hot=True)

    x = tf.placeholder(“float”, [None, 784])
    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x,W) + b)
    y_ = tf.placeholder(“float”, [None,10])
    cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)

    for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, “float”))
    print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})

    4 运行。大概之需要几秒钟时间,输出结果是91%左右。

    二、数据集

    1、数据
    60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。这样的切分很重要,在机器学习模型设计时必须有一个单独的测试数据集不用于训练而是用来评估这个模型的性能,从而更加容易把设计的模型推广到其他数据集上(泛化)。

    2、输入X
    因此,在MNIST训练数据集中,mnist.train.images 是一个形状为 [60000, 784] 的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。在此张量里的每一个元素,都表示某张图片里的某个像素的强度值,值介于0和1之间。

    3、输出y
    相对应的MNIST数据集的标签是介于0到9的数字,用来描述给定图片里表示的数字。为了用于这个教程,我们使标签数据是”one-hot vectors”。 一个one-hot向量除了某一位的数字是1以外其余各维度数字都是0。所以在此教程中,数字n将表示成一个只有在第n维度(从0开始)数字为1的10维向量。比如,标签0将表示成([1,0,0,0,0,0,0,0,0,0,0])。因此, mnist.train.labels 是一个 [60000, 10] 的数字矩阵。

    三、思路方法

    1、实现回归模型
    x = tf.placeholder(tf.float32,[None,784])
    x不是一个特定的值,而是一个占位符placeholder,我们在TensorFlow运行计算时输入这个值。我们希望能够输入任意数量的MNIST图像,每一张图展平成784维的向量。我们用2维的浮点数张量来表示这些图,这个张量的形状是[None,784 ]。(这里的None表示此张量的第一个维度可以是任何长度的。)

    W = tf.Variable(tf.zeros([784,10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x,W) + b)
    实现 y=W*x+b

    2、训练模型
    cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    一个非常常见的,非常漂亮的成本函数是“交叉熵”(cross-entropy)。交叉熵产生于信息论里面的信息压缩编码技术,但是它后来演变成为从博弈论到机器学习等其他领域里的重要技术手段。
    首先,用 tf.log 计算 y 的每个元素的对数。接下来,我们把 y_ 的每一个元素和 tf.log(y) 的对应元素相乘。最后,用 tf.reduce_sum 计算张量的所有元素的总和。(注意,这里的交叉熵不仅仅用来衡量单一的一对预测和真实值,而是所有100幅图片的交叉熵的总和。对于100个数据点的预测表现比单一数据点的表现能更好地描述我们的模型的性能。

    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    init = tf.initialize_all_variables()
    在这里,我们要求TensorFlow用梯度下降算法(gradient descent algorithm)以0.01的学习速率最小化交叉熵。梯度下降算法(gradient descent algorithm)是一个简单的学习过程,TensorFlow只需将每个变量一点点地往使成本不断降低的方向移动。当然TensorFlow也提供了其他许多优化算法:只要简单地调整一行代码就可以使用其他的算法。
    TensorFlow在这里实际上所做的是,它会在后台给描述你的计算的那张图里面增加一系列新的计算操作单元用于实现反向传播算法和梯度下降算法。然后,它返回给你的只是一个单一的操作,当运行这个操作时,它用梯度下降算法训练你的模型,微调你的变量,不断减少成本。

    sess = tf.Session()
    sess.run(init)
    初始化所有参数。

    for i in range(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)
    sess.run(train_step,feed_dict={x:batch_xs,y_:batch_ys})
    循环1000次。该循环的每个步骤中,我们都会随机抓取训练数据中的100个批处理数据点,然后我们用这些数据点作为参数替换之前的占位符来运行train_step。
    使用一小部分的随机数据来进行训练被称为随机训练(stochastic training)- 在这里更确切的说是随机梯度下降训练。在理想情况下,我们希望用我们所有的数据来进行每一步的训练,因为这能给我们更好的训练结果,但显然这需要很大的计算开销。所以,每一次训练我们可以使用不同的数据子集,这样做既可以减少计算开销,又可以最大化地学习到数据集的总体特性。

    3、评估我们的模型
    correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
    这行代码会给我们一组布尔值。为了确定正确预测项的比例,我们可以把布尔值转换成浮点数,然后取平均值。例如,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75.

    accuracy = tf.reduce_mean(tf.cast(correct_prediction,”float”))
    最后,我们计算所学习到的模型在测试数据集上面的正确率。

    print sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
    输出这个最终结果值应该大约是91%。

  • 相关阅读:
    读书笔记第四章
    读书笔记第三章
    读书笔记第二章
    读书笔记第一章
    第十章 读书笔记
    第九章 读书笔记
    第八章读书笔记
    第七章读书笔记
    第六章读书笔记
    第五章读书笔记
  • 原文地址:https://www.cnblogs.com/iriswang/p/11084658.html
Copyright © 2011-2022 走看看