zoukankan      html  css  js  c++  java
  • TFRecord读写简介+Demo 基于Ubuntu18.04+Tensorflow1.12 无WARNING

    简介

    • TFRecord是TensorFlow官方推荐使用的数据格式化存储工具。
    • 它规范了数据的读写方式。
    • 只要生成一次TFRecord,之后的数据读取和加工处理的效率都会得到提高。

    将图片转换成TFRecord

    本例,将fashion-MNIST数据转换成TFRecord,需要先下载fashion数据集到当前目录下,参考:https://github.com/zalandoresearch/fashion-mnist/tree/master/data/fashion

    import numpy as np
    import tensorflow as tf
    import gzip
    import os
    
    # Directory holding the downloaded fashion-MNIST .gz archives; main() passes it to load_mnist().
    fashion_mnist_directory = './data/fashion/'
    
    def load_mnist(path, kind='train'):
        """Load one gzip-compressed MNIST-format split from `path`.

        Args:
            path: Directory containing `<kind>-labels-idx1-ubyte.gz` and
                `<kind>-images-idx3-ubyte.gz`.
            kind: File-name prefix of the split to read, e.g. 'train' or 't10k'.

        Returns:
            A tuple `(images, labels)` of uint8 NumPy arrays; `images` is
            reshaped to (-1, 784), `labels` is left one-dimensional.
        """
        def read_payload(archive, header_bytes):
            # Decompress the whole archive and view everything after the
            # fixed-size header as unsigned bytes.
            with gzip.open(archive, 'rb') as stream:
                return np.frombuffer(stream.read(), dtype=np.uint8, offset=header_bytes)

        labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
        images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)

        # The label file is read with an 8-byte offset, the image file with 16.
        labels = read_payload(labels_path, 8)
        images = read_payload(images_path, 16).reshape(-1, 784)

        print(labels_path, "shape =", labels.shape)
        print(images_path, "shape =", images.shape)

        return images, labels
    
    
    def make_example(image, label):
        """Wrap one image/label pair in a `tf.train.Example`.

        The image is stored as a single bytes value under 'image_raw'
        (produced by `image.tobytes()`); the label is stored as a single
        int64 value under 'label'.
        """
        pixel_bytes = tf.train.BytesList(value=[image.tobytes()])
        label_value = tf.train.Int64List(value=[int(label)])
        feature_map = {
            'image_raw': tf.train.Feature(bytes_list=pixel_bytes),
            'label': tf.train.Feature(int64_list=label_value),
        }
        return tf.train.Example(features=tf.train.Features(feature=feature_map))
    
    
    def write_tfrecord(images, labels, filename):
        """Serialize every (image, label) pair into a TFRecord file.

        Args:
            images: Iterable of images, walked in lockstep with `labels`.
            labels: Array of labels; `labels.shape[0]` is the total used for
                the progress percentage.
            filename: Path of the TFRecord file to create.
        """
        total = labels.shape[0]
        # The context manager closes the writer even if serialization fails
        # part-way through.
        with tf.python_io.TFRecordWriter(filename) as writer:
            for k, (image, label) in enumerate(zip(images, labels)):
                exam = make_example(image, label)
                writer.write(exam.SerializeToString())
                if k % 100 == 0:
                    # '\r' moves back to the start of the line so the progress
                    # text is overwritten in place (hence end='').
                    print("\rwriting", filename, "%6.2f%% complited." % (100.0 * (k + 1) / total), end='')

            print("\rwriting", filename, "%6.2f%% complited." % (100.0))
    
    
    def main():
        """Load the train and test splits, then write each one to its TFRecord file."""
        # Both splits are loaded before anything is written.
        train_split = load_mnist(fashion_mnist_directory, 'train')
        test_split = load_mnist(fashion_mnist_directory, 't10k')

        outputs = (
            (train_split, 'fashion_mnist_train.tfrecords'),
            (test_split, 'fashion_mnist_test.tfrecords'),
        )
        for (images, labels), target in outputs:
            write_tfrecord(images, labels, target)
        
    # Run the conversion only when this file is executed as a script.
    if __name__ == '__main__':
        main()

    读取TFRecord数据来训练

    以下代码读取TFRecord数据用于训练,该代码改编自官方例程:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/how_tos/reading_data

    原始代码运行时报错,已修复。

    注意:在这个例子中,每次调用 _, loss_value = sess.run([train_op, loss]) 只会读取一次Batch输入,无论[]中是什么操作、有多少个操作。

    import argparse
    import os.path
    import sys
    import time
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import mnist
    
    # Parsed command-line options; stays None until the __main__ block assigns the argparse result.
    FLAGS = None
    
    # TFRecord file names; inputs() joins one of them onto FLAGS.train_dir.
    TRAIN_FILE = 'fashion_mnist_train.tfrecords'
    VALIDATION_FILE = 'fashion_mnist_test.tfrecords'
    
    
    def decode(serialized_example):
        """Parse one serialized `tf.train.Example` into an (image, label) pair.

        Args:
            serialized_example: A serialized Example carrying an 'image_raw'
                string feature and a 'label' int64 feature.

        Returns:
            `(image, label)`: the raw bytes decoded as uint8 with static shape
            [mnist.IMAGE_PIXELS], and the label cast to int32.
        """
        features = tf.parse_single_example(
            serialized_example,
            features={'image_raw': tf.FixedLenFeature([], tf.string),
                      'label':     tf.FixedLenFeature([], tf.int64)})
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        # Declare the static 1-D shape explicitly as a list. The previous
        # `(mnist.IMAGE_PIXELS)` was just a parenthesized int, not a tuple.
        image.set_shape([mnist.IMAGE_PIXELS])
        label = tf.cast(features['label'], tf.int32)
        return image, label
    
    
    def augment(image, label):
        """Identity hook marking where data augmentation could be added."""
        # OPTIONAL: a reshape to 28x28 followed by distortions could go here.
        augmented = image
        return augmented, label
    
    
    def normalize(image, label):
        """Rescale `image` from the [0, 255] range to floats in [-0.5, 0.5]; `label` is returned untouched."""
        as_float = tf.cast(image, tf.float32)
        centered = as_float * (1. / 255) - 0.5
        return centered, label
    
    
    def inputs(train, batch_size, num_epochs):
        """Build the TFRecord input pipeline and return its next-batch tensors.

        Args:
            train: Selects TRAIN_FILE when truthy, VALIDATION_FILE otherwise.
            batch_size: Number of examples per batch.
            num_epochs: Forwarded to `dataset.repeat`; any falsy value is
                replaced by None.

        Returns:
            The result of `iterator.get_next()` on the batched dataset.
        """
        epochs = num_epochs if num_epochs else None
        record_name = TRAIN_FILE if train else VALIDATION_FILE
        filename = os.path.join(FLAGS.train_dir, record_name)

        with tf.name_scope('input'):
            pipeline = tf.data.TFRecordDataset(filename)
            # Per-example transforms, applied in this order.
            for transform in (decode, augment, normalize):
                pipeline = pipeline.map(transform)
            pipeline = (pipeline
                        .shuffle(1000 + 3 * batch_size)
                        .repeat(epochs)
                        .batch(batch_size))
            iterator = pipeline.make_one_shot_iterator()
        return iterator.get_next()
    
    
    def run_training():
        """Build the model graph and run training steps until the input raises OutOfRangeError."""
        with tf.Graph().as_default():
            images, labels = inputs(train=True,
                                    batch_size=FLAGS.batch_size,
                                    num_epochs=FLAGS.num_epochs)
            logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
            loss = mnist.loss(logits, labels)
            train_op = mnist.training(loss, FLAGS.learning_rate)

            # One op that initializes both global and local variables.
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())

            with tf.Session() as sess:
                sess.run(init_op)
                step = 0
                try:
                    # Loop ends only when sess.run raises OutOfRangeError.
                    while True:
                        tick = time.time()
                        _, loss_value = sess.run([train_op, loss])
                        elapsed = time.time() - tick
                        if not step % 100:
                            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, elapsed))
                        step += 1
                except tf.errors.OutOfRangeError:
                    print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
    
    
    def main(_):
        """Entry point handed to tf.app.run; ignores its argument and starts training."""
        run_training()
    
    
    if __name__ == '__main__':
        # (flag, type, default, help) for every command-line option.
        option_specs = (
            ('--learning_rate', float, 0.01, 'Initial learning rate.'),
            ('--num_epochs',    int,   2,    'Number of epochs to run trainer.'),
            ('--hidden1',       int,   128,  'Number of units in hidden layer 1.'),
            ('--hidden2',       int,   32,   'Number of units in hidden layer 2.'),
            ('--batch_size',    int,   100,  'Batch size.'),
            ('--train_dir',     str,   './', 'Directory with the training data.'),
        )
        parser = argparse.ArgumentParser()
        for flag, value_type, default_value, help_text in option_specs:
            parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
        # Unrecognized arguments are passed on to tf.app.run after the program name.
        FLAGS, unparsed = parser.parse_known_args()
        tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

    参考了:

    • https://blog.csdn.net/gg_18826075157/article/details/78449104 
    • https://github.com/zalandoresearch/fashion-mnist/blob/master/utils/mnist_reader.py
  • 相关阅读:
    描述一下Spring Bean的生命周期
    BeanFactory和ApplicationContext有什么区别
    谈谈你对AOP的理解
    谈谈对IOC的理解
    线程池中线程复用原理
    线程池中阻塞队列的最用?为什么是先添加队列而不是先创建最大线程
    为什么使用线程池?解释下线程池参数
    去噪声论文阅读
    怎么使用有三AI完成系统性学习
    JavaCnn项目注解
  • 原文地址:https://www.cnblogs.com/xbit/p/10071848.html
Copyright © 2011-2022 走看看