zoukankan      html  css  js  c++  java
  • Tensorflow高效读取数据

    关于Tensorflow读取数据,官网给出了三种方法:

    • 供给数据(Feeding): 在TensorFlow程序运行的每一步, 让Python代码来供给数据。
    • 从文件读取数据: 在TensorFlow图的起始, 让一个输入管线从文件中读取数据。
    • 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)。

    在使用Tensorflow训练数据时,第一步为准备数据,现在我们只讨论图像数据。其数据读取大致分为:原图读取、二进制文件读取、tf标准存储文件读取。

    一、原图文件读取

           很多情况下我们的图片训练集就是原始图片本身,并没有像cifar dataset那样存成bin等格式。我们需要根据一个train_list列表,去挨个读取图片。这里我用到的方法是首先获取image list和labellist,然后读入队列中,那么对每次dequeue的内容中可以提取当前图片的路劲和label。

    1、获取文件列表

    def get_image_list(fileDir):
        imageList = []
        labelList = []
    
        filelist = os.listdir(fileDir)
        
        for var in filelist:
            imagename = os.path.join(fileDir, var)
            label = int(os.path.basename(var).split('_')[0])
            imageList.append(imagename)
            labelList.append(label)
            
        return imageList, labelList

    上述程序是从指定目录中获取文件列表和标签,其我的文件为

    image

    总共15个文件,’_’前为文件标签,记得要转化为int类型,否则后面程序或报错。

    2、将文件列表加载到内存列表中,并进行读取

    步骤分为:

    a、列表转化为tensor类型,并存到内存中

    b、 从内存列表中读取数据,进行获取图像和label

    c、 根据训练要求对数据进行转化

    d、利用batch获取批次文件

    def input_data_imageslist_slice(fileDir):
        
    #    获取文件列表
        imageList , labelList = get_image_list(fileDir)
        
    #    将文件列表和标签列表转为为tensor,进而能存入内存列表中,记得label在上面转为int,否则下面会出错,这是相对应的
        imagesTensor = tf.convert_to_tensor(imageList, dtype = tf.string)
        labelsTensor = tf.convert_to_tensor(labelList, dtype = tf.uint8)
        
    #   从内存列表中读取文件,此处只读取一个文件,并记录文件位置
        queue = tf.train.slice_input_producer([imagesTensor, labelsTensor])
        
    #    提取图片内容和标签内容,一定注意数据之间的转化;
        image_content = tf.read_file(queue[0])
        imageData = tf.image.decode_jpeg(image_content,channels=3)   #channels必须要制定,当时没指定,程序报错
        imageData = tf.image.convert_image_dtype(imageData,tf.uint8)    # 图片数据进行转化,此处为了显示而转化
        labelData = tf.cast(queue[1],tf.uint8)
    
    #    show_single_data(imageData, labelData)
        #根据数据训练尺寸,调整图片大小,此处设置为32*32
        new_size = tf.constant([IMAGE_WIDTH,IMAGE_WIDTH], dtype=tf.int32)
        image = tf.image.resize_images(imageData, new_size)
        
    #   这是数据提取关键,因为设置了batch_size,决定了每次提取数据的个数,比如此处是3,则每次为3个文件
        imageBatch, labelBatch = tf.train.shuffle_batch([image, labelData], batch_size = BATCH_SIZE,
                                                        capacity = 2000,min_after_dequeue = 1000)
        
        
        return imageBatch, labelBatch

    3、文件测试

    在文件测试中,必须添加 threads = tf.train.start_queue_runners(sess = sess),会话窗口才会从内存堆栈中读取数据。

    def test_record(filename):
        image_batch, label_batch = input_data_imageslist_slice(filename)
        
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            threads = tf.train.start_queue_runners(sess = sess)
            for i in range(5):
                val, label = sess.run([image_batch, label_batch])
                print(val.shape, label)

    我在程序中设置了batchsize为3,所以shape为(3,32,32,)后面则是label,此处很好的读取了数据

    (3, 32, 32, 3) [12  3 22]
    (3, 32, 32, 3) [ 7  2 22]
    (3, 32, 32, 3) [ 2  5 15]
    (3, 32, 32, 3) [12  5 13]
    (3, 32, 32, 3) [12  6 10]

    4、校验对比

    为了更好的得知shuffle_batch是否让文件和label对应,程序中进行了修改

    image = tf.image.resize_images(imageData, new_size)

    修改为:

    image = tf.cast(queue[0],tf.string)

    还有

    print(val.shape, label)
    修改为
    print(val, label)

    结果为:

    [b'E:\010_test_tensorflow\02_produce_data\images1\1_1.jpg'
     b'E:\010_test_tensorflow\02_produce_data\images1\12_03.jpg'
     b'E:\010_test_tensorflow\02_produce_data\images1\10_2.jpg'] [ 1 12 10]
    [b'E:\010_test_tensorflow\02_produce_data\images1\2_2.jpg'
     b'E:\010_test_tensorflow\02_produce_data\images1\22_9.jpg'
     b'E:\010_test_tensorflow\02_produce_data\images1\33_0.jpg'] [ 2 22 33]
    [b'E:\010_test_tensorflow\02_produce_data\images1\33_0.jpg'
     b'E:\010_test_tensorflow\02_produce_data\images1\7_0.jpg'
     b'E:\010_test_tensorflow\02_produce_data\images1\13_06.jpg'] [33  7 13]

    我们发现不但image和label相对应,而且还打乱了顺序,真的是很完美啊。

    二、TFRecords读取

             对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,大家自己尝试一下吧,我就不赘述了)。但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecords。

    FRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件。

            TFRecords文件包含了tf.train.Example 协议内存块(protocol buffer)(协议内存块包含了字段 Features)。我们可以写一段代码获取你的数据, 将数据填入到Example协议内存块(protocol buffer),将协议内存块序列化为一个字符串, 并且通过tf.python_io.TFRecordWriter 写入到TFRecords文件。

            从TFRecords文件中读取数据, 可以使用tf.TFRecordReadertf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。

    2.1 生成TFRecords文件

    class SaveRecord(object):
        def __init__(self,recordDir, fileDir, imageSize):        
            self._imageSize = imageSize
            
            trainRecord = os.path.join(recordDir,'train.tfrecord')
            validRecord = os.path.join(recordDir,'valid.tfrecord')
            
    #        获取文件列表
            filenames = os.listdir(fileDir)
            np.random.shuffle(filenames)
            fileNum = len(filenames)
            print('the count of images is ' + str(fileNum))
            
    #        获取训练和测试样本,比例为4:1
            splitNum = int(fileNum * 0.8)
            trainImages = filenames[ : splitNum]
            validImages = filenames[splitNum : ]
    
    #       保存数据到制定位置
            self.save_data_to_record( fileDir = fileDir, datas = trainImages, recordname = trainRecord)
            self.save_data_to_record(fileDir = fileDir,datas = validImages, recordname = validRecord)
               
        def save_data_to_record(self,fileDir, datas, recordname):
            writer = tf.python_io.TFRecordWriter(recordname)
            
            for var in datas:
                filename = os.path.join(fileDir, var)
                label = int(os.path.basename(var).split('_')[0])
                image = Image.open(filename)                # 打开图片
                image = image.resize((self._imageSize,self._imageSize))
                imageArray = image.tobytes()               #转为bytes
                
                example = tf.train.Example(features = tf.train.Features(feature = {
                          'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [imageArray]))
                          ,'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))}))
                writer.write(example.SerializeToString())
                
            writer.close()

    编程为一个类,其核心代码在于save_data_to_records,其主要流程为:

    • 初始化写入器writer,来源于tf.python_io.TFRecordWriter。
    • 遍历传入的数据,可以为文件名,意味后面二进制解析也是文件名
      • 解析文件名,获取label,这是之前处理好的
      • 利用IPL的Image读入图像数据,预处理数据:调整大小,且转为化二值化数据
      • 利用tf中的Example中获取数据,原理是利用字典对应关系,获取features,当然里面有点绕,仔细读读全是在类型转化而已
      • example二进制化,然后写入。
    • 关闭写入器

    其中里面关键点:图片bytes的转化,以及example的赋值。

             基本的,一个Example中包含FeaturesFeatures里包含Feature(这里没s)的字典。最后,Feature里包含有一个 FloatList, 或者ByteList,或者Int64List。

    2.2 读取record文件

    for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
        example = tf.train.Example()
        example.ParseFromString(serialized_example)
    
        image = example.features.feature['image'].bytes_list.value
        label = example.features.feature['label'].int64_list.value
        # 可以做一些预处理之类的
        print image, label

    上面为一个解析文件的一个例子,主要是利用example直接进行解析,简单,但是这样比较耗内存,常用的方法是利用文件队列读取。

    即利用string_input_produce,结合tf.recordreader进行数据读取,最后进行解析,其例子为:

    def read_and_decode(filename):
        #根据文件名生成一个队列
        filename_queue = tf.train.string_input_producer([filename])
    
        reader = tf.TFRecordReader()
        _, serialized = reader.read(filename_queue)   #返回文件名和文件
    
        features = tf.parse_single_example(serialized = serialized, features = {
            'image' : tf.FixedLenFeature([], tf.string),
            'label' : tf.FixedLenFeature([], tf.int64)})
         
        image = tf.decode_raw(features['image'], tf.uint8)
        image= tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 3])
    #    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    #    image = tf.cast(features['image'], tf.string)
        label = tf.cast(features['label'], tf.int32)
        
        img_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                    batch_size=BATCH_SIZE, capacity=2000,
                                                    min_after_dequeue=1000)
    
        return img_batch, label_batch

    其流程为:

    • record文件放入文件队列中
    • 初始化话RecordReader,发现这个reader和writer初始化方式不一样
    • 从队列中读取数据,记得返回两个值,我们只要第二个,此时数据为二进制数据(前面我们存入的二进制数据)
    • 根据约定解析数据,类型即为之前存储的格式。
    • 利用tf的转化获取image和label
    • 关键一步就是tf.train.shuffle_batch,利用此函数可以批量获取数据,当然是在文件列表中。

    此处对文件列表中数据读取的过程中,我们发现读取器是不一样的。比如此次是读取record的内存文件,代码为:

    filename_queue = tf.train.string_input_producer([filename])
    
        reader = tf.TFRecordReader()
        _, serialized = reader.read(filename_queue)   #返回文件名和文件

    而之前从文件列表和数据列表读取的时候为:

    imageList , labelList = get_image_list(fileDir)
        queue = tf.train.string_input_producer(imageList)
        
        reader = tf.WholeFileReader()
        _, image_content = reader.read(queue)

    而我们使用的slice_input_produce的时候,变成了tf.read_file,一定记得各个的不同。

    #   从内存列表中读取文件,此处只读取一个文件,并记录文件位置
        queue = tf.train.slice_input_producer([imagesTensor, labelsTensor])
        
    #    提取图片内容和标签内容,一定注意数据之间的转化;
        image_content = tf.read_file(queue[0])
        imageData = tf.image.decode_jpeg(image_content,channels=3)

    2.3 测试数据

    前面我们解析了shuffle_batch的好处,此处我们即检测是否读取了数据。

    def test_record(filename):
        image_batch, label_batch = read_and_decode(filename)
        
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            threads = tf.train.start_queue_runners(sess = sess)
            for i in range(5):
                val, label = sess.run([image_batch, label_batch])
                print(val.shape, label)

    此时的输出结果为:

    (10, 16, 16, 3) [15  7 33  5  4 10 13  7  1  3]
    (10, 16, 16, 3) [33 22 10 10  4 12  7  4 13 10]
    (10, 16, 16, 3) [10 10 12  7  5 15 12 22 15  5]
    (10, 16, 16, 3) [10 10  5  1 12 10  3  5 33  3]
    (10, 16, 16, 3) [12  4  7 15  4  7  4 13  5 10]

    结果表明有效的对数据进行了读取。

  • 相关阅读:
    MongoDB简单使用
    mongodb安装部署
    分布式通信-序列化
    分布式通信协议
    分布式概念
    springboot-事件
    spring-事件
    spring-@Component/@ComponentScan注解
    springboot-Date日期时间问题
    enginx:基于openresty,一个前后端统一,生态共享的webstack实现
  • 原文地址:https://www.cnblogs.com/polly333/p/7489699.html
Copyright © 2011-2022 走看看