  • Source code of the input_data script for downloading and reading the MNIST dataset (a short usage sketch follows the listing)

    # Copyright 2015 Google Inc. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================
    """Functions for downloading and reading MNIST data."""
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import gzip
    import os

    import tensorflow.python.platform

    import numpy
    from six.moves import urllib
    from six.moves import xrange  # pylint: disable=redefined-builtin

    import tensorflow as tf

    SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


    def maybe_download(filename, work_directory):
      """Download the data from Yann's website, unless it's already here."""
      if not os.path.exists(work_directory):
        os.mkdir(work_directory)
      filepath = os.path.join(work_directory, filename)
      if not os.path.exists(filepath):
        filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
      return filepath


    def _read32(bytestream):
      # The IDX file format stores 32-bit integers in big-endian byte order.
      dt = numpy.dtype(numpy.uint32).newbyteorder('>')
      return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


    def extract_images(filename):
      """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
      print('Extracting', filename)
      with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
          raise ValueError(
              'Invalid magic number %d in MNIST image file: %s' %
              (magic, filename))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        buf = bytestream.read(rows * cols * num_images)
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        return data


    def dense_to_one_hot(labels_dense, num_classes=10):
      """Convert class labels from scalars to one-hot vectors."""
      num_labels = labels_dense.shape[0]
      index_offset = numpy.arange(num_labels) * num_classes
      labels_one_hot = numpy.zeros((num_labels, num_classes))
      labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
      return labels_one_hot


    def extract_labels(filename, one_hot=False):
      """Extract the labels into a 1D uint8 numpy array [index]."""
      print('Extracting', filename)
      with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
          raise ValueError(
              'Invalid magic number %d in MNIST label file: %s' %
              (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
        if one_hot:
          return dense_to_one_hot(labels)
        return labels


    class DataSet(object):

      def __init__(self, images, labels, fake_data=False, one_hot=False,
                   dtype=tf.float32):
        """Construct a DataSet.

        one_hot arg is used only if fake_data is true.  `dtype` can be either
        `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
        `[0, 1]`.
        """
        dtype = tf.as_dtype(dtype).base_dtype
        if dtype not in (tf.uint8, tf.float32):
          raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                          dtype)
        if fake_data:
          self._num_examples = 10000
          self.one_hot = one_hot
        else:
          assert images.shape[0] == labels.shape[0], (
              'images.shape: %s labels.shape: %s' % (images.shape,
                                                     labels.shape))
          self._num_examples = images.shape[0]

          # Convert shape from [num examples, rows, columns, depth]
          # to [num examples, rows*columns] (assuming depth == 1)
          assert images.shape[3] == 1
          images = images.reshape(images.shape[0],
                                  images.shape[1] * images.shape[2])
          if dtype == tf.float32:
            # Convert from [0, 255] -> [0.0, 1.0].
            images = images.astype(numpy.float32)
            images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

      @property
      def images(self):
        return self._images

      @property
      def labels(self):
        return self._labels

      @property
      def num_examples(self):
        return self._num_examples

      @property
      def epochs_completed(self):
        return self._epochs_completed

      def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set."""
        if fake_data:
          fake_image = [1] * 784
          if self.one_hot:
            fake_label = [1] + [0] * 9
          else:
            fake_label = 0
          return [fake_image for _ in xrange(batch_size)], [
              fake_label for _ in xrange(batch_size)]
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
          # Finished epoch
          self._epochs_completed += 1
          # Shuffle the data
          perm = numpy.arange(self._num_examples)
          numpy.random.shuffle(perm)
          self._images = self._images[perm]
          self._labels = self._labels[perm]
          # Start next epoch
          start = 0
          self._index_in_epoch = batch_size
          assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]


    def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):

      class DataSets(object):
        pass

      data_sets = DataSets()

      if fake_data:
        def fake():
          return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
        data_sets.train = fake()
        data_sets.validation = fake()
        data_sets.test = fake()
        return data_sets

      TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
      TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
      TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
      TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
      VALIDATION_SIZE = 5000

      local_file = maybe_download(TRAIN_IMAGES, train_dir)
      train_images = extract_images(local_file)

      local_file = maybe_download(TRAIN_LABELS, train_dir)
      train_labels = extract_labels(local_file, one_hot=one_hot)

      local_file = maybe_download(TEST_IMAGES, train_dir)
      test_images = extract_images(local_file)

      local_file = maybe_download(TEST_LABELS, train_dir)
      test_labels = extract_labels(local_file, one_hot=one_hot)

      # Carve a validation split off the front of the training set.
      validation_images = train_images[:VALIDATION_SIZE]
      validation_labels = train_labels[:VALIDATION_SIZE]
      train_images = train_images[VALIDATION_SIZE:]
      train_labels = train_labels[VALIDATION_SIZE:]

      data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
      data_sets.validation = DataSet(validation_images, validation_labels,
                                     dtype=dtype)
      data_sets.test = DataSet(test_images, test_labels, dtype=dtype)

      return data_sets
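
  • Usage sketch: a minimal example of calling the script above, assuming it is saved as input_data.py next to your own code and that the legacy TensorFlow/six APIs it relies on are installed. The directory name MNIST_data/ and the batch size 100 are illustrative choices, not part of the original listing.

    # Hypothetical usage of the listing above, assumed saved as input_data.py.
    import input_data

    # Download the four MNIST archives into MNIST_data/ (if not already present)
    # and load them as train/validation/test splits with one-hot labels.
    mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

    print(mnist.train.num_examples)       # 55000 (60000 minus VALIDATION_SIZE)
    print(mnist.validation.num_examples)  # 5000
    print(mnist.test.num_examples)        # 10000

    # Fetch one mini-batch of 100 flattened 28x28 images and their labels.
    batch_xs, batch_ys = mnist.train.next_batch(100)
    print(batch_xs.shape)  # (100, 784), float32 values rescaled to [0, 1]
    print(batch_ys.shape)  # (100, 10) one-hot label vectors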

  • Original article: https://www.cnblogs.com/shouhutian/p/9799141.html