  • MNIST handwritten digit recognition based on TensorFlow

    This example is one that almost everyone learning TensorFlow runs at some point; it is a standard step on the learning curve. I am no exception!

    The example itself is simple. What I want to point out here is that the relevant API functions may differ between TensorFlow versions: some of the content in the Chinese TensorFlow introduction documents may no longer match the TensorFlow version you are using, and the version I use runs into exactly this problem. In addition, the MNIST image resources used in the example probably can no longer be downloaded from the original official site by many people, so I also provide a download address at the end.

    My TensorFlow version information:

    >>> import tensorflow as tf
    >>> print tf.VERSION    
    1.0.1
    >>> print tf.GIT_VERSION
    v1.0.0-65-g4763edf-dirty
    >>> print tf.COMPILER_VERSION
    4.8.4

    Below is the result of running the code from the Chinese TensorFlow site I referenced in my own environment.

     1 [root@bogon tensorflow]# python
     2 Python 2.7.5 (default, Nov  6 2016, 00:28:07) 
     3 [GCC 4.8.5 20150623 (Red Hat 4.8.5-11)] on linux2
     4 Type "help", "copyright", "credits" or "license" for more information.
     5 >>> import tensorflow.examples.tutorials.mnist.input_data as input_data
     6 >>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
     7 Traceback (most recent call last):
     8   File "<stdin>", line 1, in <module>
     9   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py", line 211, in read_data_sets
    10     SOURCE_URL + TRAIN_IMAGES)
    11   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 208, in maybe_download
    12     temp_file_name, _ = urlretrieve_with_retry(source_url)
    13   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 165, in wrapped_fn
    14     return fn(*args, **kwargs)
    15   File "/usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py", line 190, in urlretrieve_with_retry
    16     return urllib.request.urlretrieve(url, filename)
    17   File "/usr/lib64/python2.7/urllib.py", line 94, in urlretrieve
    18     return _urlopener.retrieve(url, filename, reporthook, data)
    19   File "/usr/lib64/python2.7/urllib.py", line 240, in retrieve
    20     fp = self.open(url, data)
    21   File "/usr/lib64/python2.7/urllib.py", line 203, in open
    22     return self.open_unknown_proxy(proxy, fullurl, data)
    23   File "/usr/lib64/python2.7/urllib.py", line 222, in open_unknown_proxy
    24     raise IOError, ('url error', 'invalid proxy for %s' % type, proxy)
    25 IOError: [Errno url error] invalid proxy for http: '10.90.1.101:8080'
    26 >>> 
    27 >>> mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    28 Extracting MNIST_data/train-images-idx3-ubyte.gz
    29 Extracting MNIST_data/train-labels-idx1-ubyte.gz
    30 Extracting MNIST_data/t10k-images-idx3-ubyte.gz
    31 Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
    32 >>> import tensorflow as tf
    33 >>> x = tf.placeholder(tf.float32, [None, 784])
    34 >>> W = tf.Variable(tf.zeros([784,10]))
    35 >>> b = tf.Variable(tf.zeros([10]))
    36 >>> y = tf.nn.softmax(tf.matmul(x,W) + b)
    37 >>> y_ = tf.placeholder("float", [None,10])
    38 >>> cross_entropy = -tf.reduce_sum(y_*tf.log(y))
    39 >>> train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
    40 >>> init = tf.initialize_all_variables()
    41 WARNING:tensorflow:From <stdin>:1: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
    42 Instructions for updating:
    43 Use `tf.global_variables_initializer` instead.
    44 >>> init = tf.global_variables_initializer()   
    45 >>> sess = tf.Session()
    46 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE3 instructions, but these are available on your machine and could speed up CPU computations.
    47 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
    48 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
    49 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
    50 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
    51 W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
    52 >>> sess.run(init)
    53 >>> for i in range(1000):
    54 ...   batch_xs, batch_ys = mnist.train.next_batch(100)
    55 ...   sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
    56 ... 
    57 >>> correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    58 >>> accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    59 >>> print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    60 0.9088
    61 >>> 

    The log above is the complete record of my test run; it shows the following points:

    1. The IOError in the traceback (log lines 7-25) appears because my local machine goes online through a proxy. In this step TensorFlow uses urllib to download the MNIST image resources, and because of this network setup the download of the resource files fails. One possible workaround is to hand urllib a well-formed proxy URL before calling read_data_sets, as in the sketch below.
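
    A minimal sketch of that workaround, assuming the proxy 10.90.1.101:8080 from the error log is actually usable and only lacks the http:// scheme prefix that Python 2's urllib expects (adjust the address for your own environment):

    # Sketch: give urllib a well-formed proxy URL before downloading MNIST.
    # The proxy address is taken from my error log and is only an example.
    import os
    os.environ['http_proxy'] = 'http://10.90.1.101:8080'
    os.environ['https_proxy'] = 'http://10.90.1.101:8080'

    import tensorflow.examples.tutorials.mnist.input_data as input_data
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)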

    2. Which resource files need to be downloaded? Look at the code around line 211 of the file /usr/lib/python2.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py from the traceback:

    def read_data_sets(train_dir,
                       fake_data=False,
                       one_hot=False,
                       dtype=dtypes.float32,
                       reshape=True,
                       validation_size=5000):
      if fake_data:
    
        def fake():
          return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
    
        train = fake()
        validation = fake()
        test = fake()
        return base.Datasets(train=train, validation=validation, test=test)
    
      TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
      TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
      TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
      TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    
      local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
                                       SOURCE_URL + TRAIN_IMAGES)
      with open(local_file, 'rb') as f:
        train_images = extract_images(f)
    
      local_file = base.maybe_download(TRAIN_LABELS, train_dir,
                                       SOURCE_URL + TRAIN_LABELS)
      with open(local_file, 'rb') as f:
        train_labels = extract_labels(f, one_hot=one_hot)
    
      local_file = base.maybe_download(TEST_IMAGES, train_dir,
                                       SOURCE_URL + TEST_IMAGES)
      with open(local_file, 'rb') as f:
        test_images = extract_images(f)
    
      local_file = base.maybe_download(TEST_LABELS, train_dir,
                                       SOURCE_URL + TEST_LABELS)
      with open(local_file, 'rb') as f:
        test_labels = extract_labels(f, one_hot=one_hot)
    
      if not 0 <= validation_size <= len(train_images):
        raise ValueError(
            'Validation size should be between 0 and {}. Received: {}.'
            .format(len(train_images), validation_size))
    
      validation_images = train_images[:validation_size]
      validation_labels = train_labels[:validation_size]
      train_images = train_images[validation_size:]
      train_labels = train_labels[validation_size:]
    
      train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
      validation = DataSet(validation_images,
                           validation_labels,
                           dtype=dtype,
                           reshape=reshape)
      test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)
    
      return base.Datasets(train=train, validation=validation, test=test)

    The four archive names above (TRAIN_IMAGES, TRAIN_LABELS, TEST_IMAGES, TEST_LABELS) are the image resource files that need to be downloaded. My network environment cannot fetch them directly, so I obtained them through other channels. Although the direct download failed, the program had still created the MNIST_data directory under the path I started python from, and I put the downloaded train-images-idx3-ubyte.gz, train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz into that MNIST_data directory.

    Then, running mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) again succeeds without error and produces the output shown on log lines 28-31. (If the official MNIST site is reachable from your network, the archives can also be fetched directly, as in the sketch below.)
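
    A minimal sketch of fetching the four archives by hand into MNIST_data/, assuming the classic address http://yann.lecun.com/exdb/mnist/ (which, as far as I can tell, is the SOURCE_URL this TensorFlow version uses) is accessible from your environment:

    import os
    try:
        from urllib import urlretrieve          # Python 2
    except ImportError:
        from urllib.request import urlretrieve  # Python 3

    SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'  # assumed mirror; adjust if unreachable
    FILES = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
             't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']

    if not os.path.exists('MNIST_data'):
        os.makedirs('MNIST_data')
    for name in FILES:
        # Download each archive into MNIST_data/ so read_data_sets() finds it locally.
        urlretrieve(SOURCE_URL + name, os.path.join('MNIST_data', name))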

    3. The code on log line 40 triggers a WARNING advising use of the newer function; following the hint and running the code on log line 44 works fine. This shows that API compatibility between versions needs attention when working with TensorFlow; a small version-tolerant sketch follows below.
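
    A small sketch of initialization that tolerates both the old and the new API, based only on what the warning message suggests:

    import tensorflow as tf

    # Prefer the new initializer; fall back to the deprecated one on older versions.
    if hasattr(tf, 'global_variables_initializer'):
        init = tf.global_variables_initializer()
    else:
        init = tf.initialize_all_variables()

    sess = tf.Session()
    sess.run(init)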

    4. After training and evaluation finish, the result on log line 60 shows a recognition accuracy of 0.9088.
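
    For reference, here is the whole interactive session collected into one standalone script. This is only a sketch: it assumes the four .gz archives are already under MNIST_data/ and uses exactly the model and hyperparameters from the session log.

    import tensorflow as tf
    import tensorflow.examples.tutorials.mnist.input_data as input_data

    # Load MNIST (expects the .gz archives under MNIST_data/).
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    # Softmax regression model: y = softmax(x*W + b).
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)

    # Cross-entropy loss against the one-hot labels.
    y_ = tf.placeholder(tf.float32, [None, 10])
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Train with mini-batches of 100 samples for 1000 steps.
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Evaluate on the held-out test set (about 0.91 in my run).
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))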

    The theory behind the handwriting-recognition performance of this MNIST example is not the focus of this post; readers can study it in MNIST-related articles on their own.

    Finally, I attach a download address for the image resources used in this MNIST example; click it to download. (Note: points are required for the download, sorry for the inconvenience.)

  • Original article: https://www.cnblogs.com/shihuc/p/6599170.html