zoukankan      html  css  js  c++  java
  • TFRecord读写简介+Demo 基于Ubuntu18.04+Tensorflow1.12 无WARNING

    简介

    • TFRecord是TensorFlow官方推荐使用的数据格式化存储工具。
    • 它规范了数据的读写方式。
    • 只要生成一次TFRecord,之后的数据读取和加工处理的效率都会得到提高。

    将图片转换成TFRecord

    本例,将fashion-MNIST数据转换成TFRecord,需要先下载fashion数据集到当前目录下,参考:https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion

    import numpy as np
    import tensorflow as tf
    import gzip
    import os
    
    # Directory that holds the fashion-MNIST `*-ubyte.gz` archives; main() passes
    # it to load_mnist() as the `path` argument for both the train and t10k splits.
    fashion_mnist_directory = './data/fashion/'
    
    def load_mnist(path, kind='train'):
        """Load one split of the gzip-compressed (fashion-)MNIST IDX files.

        Args:
            path: Directory containing the ``*-ubyte.gz`` files.
            kind: File-name prefix of the split (``'train'`` or ``'t10k'``).

        Returns:
            An ``(images, labels)`` tuple of uint8 arrays. ``images`` is reshaped
            to ``(-1, 784)``; ``labels`` is one-dimensional.
        """
        def read_payload(gz_path, header_size):
            # Decompress the whole file and view everything after the header
            # bytes as a flat uint8 array.
            with gzip.open(gz_path, 'rb') as stream:
                return np.frombuffer(stream.read(), dtype=np.uint8, offset=header_size)

        labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
        images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

        # The label file skips 8 leading bytes, the image file skips 16.
        labels = read_payload(labels_path, 8)
        images = read_payload(images_path, 16).reshape(-1, 784)

        for source, array in ((labels_path, labels), (images_path, images)):
            print(source, "shape =", array.shape)

        return images, labels
    
    
    def make_example(image, label):
        """Pack one image/label pair into a ``tf.train.Example``.

        Args:
            image: Array-like object exposing ``tobytes()``; its raw bytes are
                stored under the ``'image_raw'`` key.
            label: Value convertible with ``int()``; stored under ``'label'``.

        Returns:
            A ``tf.train.Example`` with a bytes feature ``'image_raw'`` and an
            int64 feature ``'label'``.
        """
        raw_pixels = tf.train.BytesList(value=[image.tobytes()])
        class_id = tf.train.Int64List(value=[int(label)])
        feature_map = {
            'image_raw': tf.train.Feature(bytes_list=raw_pixels),
            'label': tf.train.Feature(int64_list=class_id),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature_map))
    
    
    def write_tfrecord(images, labels, filename):
        """Serialize image/label pairs into a TFRecord file, printing progress.

        Args:
            images: Iterable of per-example image arrays (each passed to
                ``make_example``).
            labels: Array of labels; ``labels.shape[0]`` is used as the total
                record count for the progress percentage.
            filename: Path of the TFRecord file to create.
        """
        total = labels.shape[0]

        # The context manager closes the writer even if serialization fails
        # part-way through, so the file handle is never leaked.
        with tf.python_io.TFRecordWriter(filename) as writer:
            for k, (image, label) in enumerate(zip(images, labels)):
                exam = make_example(image, label)
                writer.write(exam.SerializeToString())
                if k % 100 == 0:
                    # "\r" moves the cursor back to the start of the line so the
                    # progress message is redrawn in place; flush because the
                    # line has no trailing newline to trigger output.
                    print("\rwriting", filename,
                          "%6.2f%% complited." % (100.0 * (k + 1) / total),
                          end='', flush=True)

        # Final status line, terminated with a newline.
        print("\rwriting", filename, "%6.2f%% complited." % (100.0))
    
    
    def main():
        """Convert the fashion-MNIST train and t10k splits into TFRecord files.

        Both splits are loaded first, then written to
        ``fashion_mnist_train.tfrecords`` and ``fashion_mnist_test.tfrecords``
        in the current working directory.
        """
        train_split = load_mnist(fashion_mnist_directory, 'train')
        test_split = load_mnist(fashion_mnist_directory, 't10k')

        # Each split is an (images, labels) pair, unpacked into the first two
        # positional arguments of write_tfrecord.
        write_tfrecord(*train_split, 'fashion_mnist_train.tfrecords')
        write_tfrecord(*test_split, 'fashion_mnist_test.tfrecords')


    # Run the conversion only when executed as a script.
    if __name__ == '__main__':
        main()

    读取TFRecord数据来训练

    以下代码读取TFRecord数据用于训练,该代码改编自官方例程:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/how_tos/reading_data

    原始代码运行时报错,已修复。

    注意:在这个例子中,每执行一次 _, loss_value = sess.run([train_op, loss]),只会读取一个Batch的输入数据,无论[]中包含哪些操作、有多少个操作。

    import argparse
    import os.path
    import sys
    import time
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import mnist
    
    # Parsed command-line options; assigned from parser.parse_known_args() in the
    # __main__ guard before tf.app.run() invokes main().
    FLAGS = None
    
    # File names joined onto FLAGS.train_dir by inputs(): TRAIN_FILE when
    # train is truthy, VALIDATION_FILE otherwise.
    TRAIN_FILE = 'fashion_mnist_train.tfrecords'
    VALIDATION_FILE = 'fashion_mnist_test.tfrecords'
    
    
    def decode(serialized_example):
        """Parse one serialized Example into an ``(image, label)`` tensor pair.

        Args:
            serialized_example: Scalar string tensor holding one serialized
                ``tf.train.Example`` with ``'image_raw'`` and ``'label'`` features.

        Returns:
            ``(image, label)``: ``image`` is the uint8 decoding of ``'image_raw'``
            with its static shape set from ``mnist.IMAGE_PIXELS``; ``label`` is
            the ``'label'`` feature cast to int32.
        """
        schema = {
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        }
        parsed = tf.parse_single_example(serialized_example, features=schema)

        pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
        # decode_raw cannot infer the length statically, so pin it here.
        pixels.set_shape((mnist.IMAGE_PIXELS))

        class_id = tf.cast(parsed['label'], tf.int32)
        return pixels, class_id
    
    
    def augment(image, label):
        """Data-augmentation hook; currently returns its inputs unchanged.

        Args:
            image: The example's image.
            label: The example's label.

        Returns:
            The same ``(image, label)`` pair, untouched.
        """
        # OPTIONAL: reshape to a 28x28 image and apply distortions here.
        sample = (image, label)
        return sample
    
    
    def normalize(image, label):
        """Rescale `image` from [0, 255] to floats in [-0.5, 0.5].

        Args:
            image: Image tensor; cast to float32 before scaling.
            label: Passed through unchanged.

        Returns:
            ``(scaled_image, label)``.
        """
        as_float = tf.cast(image, tf.float32)
        centered = as_float * (1. / 255) - 0.5
        return centered, label
    
    
    def inputs(train, batch_size, num_epochs):
        """Build the tf.data input pipeline and return its next-batch tensors.

        Args:
            train: Truthy to read ``TRAIN_FILE``, falsy to read
                ``VALIDATION_FILE``; either is looked up under ``FLAGS.train_dir``.
            batch_size: Number of examples per batch; also sizes the shuffle
                buffer (``1000 + 3 * batch_size``).
            num_epochs: Repeat count for the dataset. A falsy value is replaced
                by ``None`` before being handed to ``Dataset.repeat``.

        Returns:
            The result of ``get_next()`` on a one-shot iterator over the batched
            dataset.
        """
        epochs = num_epochs if num_epochs else None
        record_file = TRAIN_FILE if train else VALIDATION_FILE
        filename = os.path.join(FLAGS.train_dir, record_file)

        with tf.name_scope('input'):
            # Stage order: parse -> augment -> normalize -> shuffle -> repeat -> batch.
            pipeline = (
                tf.data.TFRecordDataset(filename)
                .map(decode)
                .map(augment)
                .map(normalize)
                .shuffle(1000 + 3 * batch_size)
                .repeat(epochs)
                .batch(batch_size)
            )
            iterator = pipeline.make_one_shot_iterator()
        return iterator.get_next()
    
    
    def run_training():
        """Train the MNIST model on the TFRecord input until the data runs out.

        Builds the input pipeline, model, loss and training op in a fresh graph,
        then runs training steps in a session until the iterator raises
        ``tf.errors.OutOfRangeError``. The loss is printed every 100 steps.
        """
        with tf.Graph().as_default():
            images, labels = inputs(
                train=True,
                batch_size=FLAGS.batch_size,
                num_epochs=FLAGS.num_epochs,
            )

            logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
            loss = mnist.loss(logits, labels)
            train_op = mnist.training(loss, FLAGS.learning_rate)

            global_init = tf.global_variables_initializer()
            local_init = tf.local_variables_initializer()
            init_op = tf.group(global_init, local_init)

            with tf.Session() as sess:
                sess.run(init_op)
                step = 0
                try:
                    # Loop ends when the input iterator is exhausted.
                    while True:
                        began = time.time()
                        # One run() call pulls one batch, shared by both fetches.
                        _, loss_value = sess.run([train_op, loss])
                        duration = time.time() - began
                        if not step % 100:
                            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
                        step += 1
                except tf.errors.OutOfRangeError:
                    print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
    
    
    def main(_):
        """Entry point for ``tf.app.run``; the argv argument is ignored."""
        run_training()


    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        # (flag, type, default, help) for every supported command-line option.
        option_table = (
            ('--learning_rate', float, 0.01, 'Initial learning rate.'),
            ('--num_epochs', int, 2, 'Number of epochs to run trainer.'),
            ('--hidden1', int, 128, 'Number of units in hidden layer 1.'),
            ('--hidden2', int, 32, 'Number of units in hidden layer 2.'),
            ('--batch_size', int, 100, 'Batch size.'),
            ('--train_dir', str, './', 'Directory with the training data.'),
        )
        for flag, value_type, fallback, description in option_table:
            parser.add_argument(flag, type=value_type, default=fallback, help=description)

        # Unrecognized arguments are forwarded to tf.app.run.
        FLAGS, unparsed = parser.parse_known_args()
        tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

    参考了:

    • https://blog.csdn.net/gg_18826075157/article/details/78449104 
    • https://github.com/zalandoresearch/fashion-mnist/blob/master/utils/mnist_reader.py
  • 相关阅读:
    变量的创建和初始化
    HDU 1114 Piggy-Bank (dp)
    HDU 1421 搬寝室 (dp)
    HDU 2059 龟兔赛跑 (dp)
    HDU 2571 命运 (dp)
    HDU 1574 RP问题 (dp)
    HDU 2577 How to Type (字符串处理)
    HDU 1422 重温世界杯 (dp)
    HDU 2191 珍惜现在,感恩生活 (dp)
    HH实习 acm算法部 1689
  • 原文地址:https://www.cnblogs.com/xbit/p/10071848.html
Copyright © 2011-2022 走看看