  • TFRecord and a First Program on PAI

    How TFRecord works

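    TFRecord is TensorFlow's standard record format, and arbitrary data can be converted into it. The format fits well with the architecture of networked applications and supports multi-threaded parallel data processing, so it is fast. A TFRecords file holds tf.train.Example protocol buffers, and each Example contains a Features field. To write data, fill a tf.train.Example protocol buffer, serialize it to a string, and write that string to the TFRecords file with the tf.python_io.TFRecordWriter class.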

    To read data from a TFRecords file, use tf.TFRecordReader together with the tf.parse_single_example parser. parse_single_example parses an Example protocol buffer into tensors.

    Converting the MNIST dataset to TFRecord and reading it back

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data
    import numpy as np
    from PIL import Image
    
    # Wrap an integer value as an int64_list feature, matching the tf.train.Example definition
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
    
    # Wrap a byte string as a bytes_list feature, matching the tf.train.Example definition
    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
    
    def mnsit_tfreords(images,labels,filename,num_examples):
        
        writer = tf.python_io.TFRecordWriter(filename)
    
        for index in range(num_examples):
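            # Serialize the image's pixel array to raw bytes (tobytes replaces the deprecated tostring)
            image_raw = images[index].tobytes()
            # Pack one image into an Example protocol buffer that holds all of its fields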
            example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw': _bytes_feature(image_raw),
                'label': _int64_feature(np.argmax(labels[index]))}))
            writer.write(example.SerializeToString())
    
        writer.close()
    
    def read_image(filename):
        reader = tf.TFRecordReader()
        # Create the input queue with tf.train.string_input_producer
        filename_queue = tf.train.string_input_producer([filename])
        # Read one example from the file
        _, serialized_example = reader.read(filename_queue)
        # Parse the example that was read
        features = tf.parse_single_example(
            serialized_example,
            features={
                # The feature spec here must match the format used when the data was written
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })
        # tf.decode_raw turns the byte string back into the image's pixel array
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image = tf.reshape(image, [28, 28, 1])
        # tf.cast converts the value to the desired data type
        label = tf.cast(features['label'], tf.int32)
        return image,label
    def read_image_batch(filename):
        image,label = read_image(filename)
        num_preprocess_threads = 1
        batch_size = 128
        min_queue_examples = 100
        image_batch, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples +  batch_size,
            min_after_dequeue=min_queue_examples)
        return image_batch,label_batch
    
    # Read data back from a TFRecord file
    def read_tfrecords(filename):
        image_batch,label_batch = read_image_batch(filename)
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
            # Start the queue-runner threads that feed the input queues
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            for i in range(5):
                image, label = sess.run([image_batch, label_batch])
                result = Image.fromarray(image[0].reshape([28,28]))
                result.save(str(i) + '.png')
            coord.request_stop()
            coord.join(threads)
    if __name__ == '__main__':
        mnist = input_data.read_data_sets("D:/Git/MyCode/GAN/Datas/mnist", dtype=tf.uint8, one_hot=True)
        images = mnist.train.images
        labels = mnist.train.labels
        num_examples = mnist.train.num_examples
        filename = "./train.tfrecords"
        mnsit_tfreords(images,labels,filename,num_examples)
        images = mnist.test.images
        labels = mnist.test.labels
        num_examples = mnist.test.num_examples
        filename = "./test.tfrecords"
        mnsit_tfreords(images,labels,filename,num_examples)
        read_tfrecords("./test.tfrecords")
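
    The writer stores each image as raw bytes, and the reader has to supply the dtype and shape again because the bytes carry neither. The NumPy sketch below mirrors that round trip outside TensorFlow; the pixel values are made up for illustration, and np.frombuffer only stands in for tf.decode_raw here, it is not what the code above calls.

```python
import numpy as np

# A made-up 784-pixel uint8 "image"; real MNIST pixels would come from the dataset.
image = (np.arange(28 * 28) % 256).astype(np.uint8)

# What the writer stores in the 'image_raw' feature: one byte per pixel.
raw = image.tobytes()

# What the reader does: reinterpret the bytes as uint8, then restore the shape.
flat = np.frombuffer(raw, dtype=np.uint8)   # plays the role of tf.decode_raw(..., tf.uint8)
restored = flat.reshape(28, 28, 1)          # plays the role of tf.reshape(image, [28, 28, 1])

# Decoding with a different dtype silently changes the element count.
wrong = np.frombuffer(raw, dtype=np.uint16)

print(len(raw), flat.shape, restored.shape, wrong.shape)
```

    This is why the dtype passed to tf.decode_raw has to agree with the dtype the images had when they were written.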

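    tf.train.batch reads data in order: the queue stays ordered and is refilled continuously as elements are dequeued. tf.train.shuffle_batch instead dequeues elements in random order, so the elements left in the queue are shuffled as well. capacity is the maximum length of the queue, and every batch is drawn from within that window. min_after_dequeue is the number of elements that must remain in the queue after a dequeue; within the limit set by capacity, the larger min_after_dequeue is, the more thoroughly the data is mixed.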
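
    To get a feel for how capacity and min_after_dequeue interact, here is a toy pure-Python model of a shuffling queue. It is only a sketch: shuffle_batch_sim is a hypothetical helper, not a TensorFlow API, and it ignores threading and the endless re-reading of files that the real input queue does.

```python
import random

def shuffle_batch_sim(stream, batch_size, capacity, min_after_dequeue, seed=0):
    """Yield batches drawn at random from a bounded buffer (toy model)."""
    if capacity < batch_size + min_after_dequeue:
        raise ValueError("capacity must be >= batch_size + min_after_dequeue")
    rng = random.Random(seed)
    it = iter(stream)
    buffer = []
    exhausted = False
    while True:
        # Refill the buffer up to its capacity while input remains.
        while not exhausted and len(buffer) < capacity:
            try:
                buffer.append(next(it))
            except StopIteration:
                exhausted = True
        # While input remains, min_after_dequeue elements must stay behind.
        needed = batch_size if exhausted else batch_size + min_after_dequeue
        if len(buffer) < needed:
            return
        batch = []
        for _ in range(batch_size):
            # Pick a random element, swap it to the end, and remove it.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            batch.append(buffer.pop())
        yield batch

data = list(range(1000))
# Tiny window: the first batch can only contain the first 9 elements.
first_small = next(shuffle_batch_sim(data, batch_size=8, capacity=9, min_after_dequeue=1))
# Large window: the first batch is drawn from the first 508 elements.
first_large = next(shuffle_batch_sim(data, batch_size=8, capacity=508, min_after_dequeue=500))
print(sorted(first_small))
print(sorted(first_large))
```

    With the small window the first batch is almost in file order, while the large window lets it reach much further into the data, which is the effect described above.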

    First PAI program: MNIST classification

    Once the MNIST dataset has been saved in TensorFlow's standard format, it can also be read directly with TensorFlow from PAI's OSS storage, which is convenient.

    import os
    import tensorflow as tf
    import argparse
    
    FLAGS = None
    
    
    def read_image(filename):
        reader = tf.TFRecordReader()
        filename_queue = tf.train.string_input_producer([filename])
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })
        image = tf.decode_raw(features['image_raw'], tf.uint8)
        image = tf.reshape(image, [784])
        image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
        label = tf.cast(features['label'], tf.int32)
        return image,label
    
    def read_image_batch(filename,batch_size = 128):
        image,label = read_image(filename)
        num_preprocess_threads = 10
        min_queue_examples = 100
        image_batch, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            num_threads=num_preprocess_threads,
            capacity=min_queue_examples +  batch_size,
            min_after_dequeue=min_queue_examples)
        one_hot_labels = tf.to_float(tf.one_hot(label_batch, 10, 1, 0))
        return image_batch,one_hot_labels
    
    def main(_):
        train_file_path = os.path.join(FLAGS.buckets, "train.tfrecords")
        test_file_path = os.path.join(FLAGS.buckets, "test.tfrecords")
        ckpt_path = os.path.join(FLAGS.checkpointDir, "model.ckpt")
    
        train_images,train_labels = read_image_batch(train_file_path)
        test_images,test_labels = read_image_batch(test_file_path)
    
        W = tf.get_variable('weights', [784, 10],initializer = tf.random_normal_initializer(stddev = 0.02))
        B = tf.get_variable('biases', [10],initializer = tf.constant_initializer(0.0))
    
        x = tf.reshape(train_images,[-1,784])
        y = tf.to_float(train_labels)
        y_ = tf.matmul(x, W) + B
        
        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_))
        train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
        x_test = tf.reshape(test_images, [-1, 784])
        y_pred = tf.matmul(x_test, W) + B
        y_test = tf.to_float(test_labels)
        correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_test, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        saver = tf.train.Saver()
        with tf.Session() as sess:
            tf.global_variables_initializer().run()
    
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        
            for i in range(1000):
                sess.run(train_step)
                if (i + 1) % 10 == 0:
                    # Accuracy here is measured on a single shuffled batch of test images
                    print("step:", i + 1, "accuracy:", sess.run(accuracy))
    
            print("accuracy: " , sess.run(accuracy))
    
            save_path = saver.save(sess, ckpt_path)
            print("Model saved in file: %s" % save_path)
    
            coord.request_stop()
            coord.join(threads)
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        parser.add_argument('--buckets', type=str, default='',help='input data path')
        parser.add_argument('--checkpointDir', type=str, default='',help='output model path')
        FLAGS, _ = parser.parse_known_args()
        tf.app.run(main=main)
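
    The preprocessing in read_image and the label encoding in read_image_batch are easy to check outside TensorFlow. The NumPy sketch below reproduces the same arithmetic on a few made-up pixel values and labels; np.eye is used here as a stand-in for tf.one_hot.

```python
import numpy as np

# Made-up pixel values covering the uint8 range.
pixels = np.array([0, 51, 255], dtype=np.uint8)

# Same arithmetic as read_image: scale [0, 255] to [-0.5, 0.5].
scaled = pixels.astype(np.float32) * (1.0 / 255) - 0.5

# Made-up integer labels, as stored in the 'label' feature.
labels = np.array([3, 0, 9])

# One-hot encode into 10 classes; row i has a 1.0 at column labels[i].
one_hot = np.eye(10, dtype=np.float32)[labels]

print(scaled)
print(one_hot.argmax(axis=1))
```

    Taking argmax over each one-hot row recovers the integer label, which is what the accuracy computation relies on.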
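
    The script takes its input and output locations from the --buckets and --checkpointDir flags and uses parse_known_args so that any extra arguments do not cause an error. The sketch below runs the same parser on a hypothetical command line; the oss:// paths and the --extra flag are made up for illustration and are not taken from PAI's documentation.

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument('--buckets', type=str, default='', help='input data path')
parser.add_argument('--checkpointDir', type=str, default='', help='output model path')

# Hypothetical arguments; a real job would receive its own values.
argv = ['--buckets', 'oss://my-bucket/mnist/',
        '--checkpointDir', 'oss://my-bucket/ckpt/',
        '--extra', '1']

# Unknown arguments are returned separately instead of raising an error.
flags, unparsed = parser.parse_known_args(argv)

train_file_path = os.path.join(flags.buckets, 'train.tfrecords')
ckpt_path = os.path.join(flags.checkpointDir, 'model.ckpt')

print(train_file_path)
print(ckpt_path)
print(unparsed)
```

    Running the parser like this locally, with paths that point at a local directory, is a cheap way to test the script before submitting it.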

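    *Pitfalls I ran into: PAI does not accept Chinese characters, and I submitted the job before it was working locally. Also, when tf.train.match_filenames_once is used to find the tfrecords files, it keeps the matched filenames in a TensorFlow local variable, so tf.local_variables_initializer() has to be run as well; otherwise the files cannot be read. Its result should be passed on directly rather than copied into a temporary variable first.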

  • Original post: https://www.cnblogs.com/yutingmoran/p/8592561.html