zoukankan      html  css  js  c++  java
  • Tensorflow多线程输入数据处理框架

    图像预处理方法可以减少无关因素对图像识别模型效果的影响,但会减慢整个训练过程。为了避免图像预处理成为神经网络模型训练效率的瓶颈,tensorflow提供了一套多线程处理输入数据的框架。本博客将详细介绍这个框架。


    队列和多线程

    队列不仅是一种数据结构,也提供了多线程机制。tensorflow提供了Enqueue、EnqueueMany和Dequeue三种方式修改队列的状态。Tensorflow提供了FIFOQueue和RandomShuffleQueue两种队列。

            队列不仅是一种数据结构,还是异步计算张量取值的一个重要机制,比如多个线程可以同时向一个队列中写元素,或者读取队列中的元素。Tensorflow提供了tf.Coordinator和tf.QueueRunner两个类来完成多线程协同功能。

            tf.Coordinator主要协同多个线程一起停止,并提供should_stop和request_stop和join三个函数。启动的线程需要一直查询tf.Coordinator提供的should_stop函数,当这个函数返回True时,则当前线程需要退出。每一个启动的线程都可以调用request_stop函数来通知其他线程退出。当某一个线程调用request_stop函数后,should_stop返回值将被设置为True,其他线程可以同时终止。

        tf.QueueRunner主要用于启动多个线程来操作同一个队列。在使用tf.train.QueueRunner时,需要明确调用tf.train.start_queue_runners来启动所有线程。

    在tensorflow中,队列和变量类似,都是计算图上有状态的节点。其他节点可以修改它的状态。以下程序展示了如何使用这些函数来操作一个队列。

    #coding :utf-8
    
    import tensorflow as tf
    
    #创建一个先进先出队列,指定队列中最多可以保存两个元素,并指定数据类型
    q = tf.FIFOQueue(2,'int32')
    #使用queue_many函数来初始化队列中的元素。
    # 和变量初始化类似,在队列使用之前需要明确的调用这个初始化过程
    init = q.enqueue_many(([0,10],))
    #使用dequeue函数将队列中的第一个元素出队列
    x = q.dequeue()
    y = x + 1
    #将y的值重新加入队列
    q_inc = q.enqueue([y])
    
    with tf.Session() as sess:
        init.run()
        for _ in range(5):
            v,_ = sess.run([x,q_inc])
            print(v)
    
    '''
    输出:0
         10
         1
         11
         2
    '''

    输入文件队列

    tensorflow提供创建队列的两种方式:tf.train.string_input_producer()tf.train.slice_input_producer()

    tf.train.slice_input_producer([image,label],num_epochs=10),随机产生一个图片和标签,num_epochs=10,则表示把所有的数据过10遍,使用完所有的图片数据为一个epoch,这是重复使用10次。上面的用法表示你的数据集和标签已经全部加载到内存中了,如果数据集非常庞大,我们通过这个函数也可以只加载图片的路径,放入图片的path。

    # -*- coding: utf-8 -*-
    
    import tensorflow as tf
    import glob
    import matplotlib.pyplot as plt
    import time
     
    datapath=r'img/'
    imgpath = glob.glob(datapath+'*.jpg')
    # 将路径转化成张量形式
    imgpath = tf.convert_to_tensor(imgpath)
     
    # 产生一个队列每次随机产生一张图片地址
    # 注意这里要放在数组里面
     
    image = tf.train.slice_input_producer([imgpath])
    # 得到一个batch的图片地址
    # 由于tf.train.slice_input_producer()函数默认是随机产生一个实例
    # 所以在这里直接使用tf.train.batch()直接获得一个batch的数据即可
    # 没有必要再去使用tf.trian.shuffle_batch() 速度会慢
    img_batch = tf.train.batch([image],batch_size=20,capacity=100)
     
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        thread = tf.train.start_queue_runners(sess,coord)
        i = 0
        try:
            while not coord.should_stop():
                imgs = sess.run(img_batch)
                print(imgs)
                #fig = plt.figure()
                for i,path in enumerate(imgs):
                    img = plt.imread(path[0].decode('utf-8'))
                    #axes = fig.add_subplot(5,4,i+1)
                    #axes.imshow(img)
                    #axes.axis('off')
                #plt.ion()
                #plt.show()
                time.sleep(1)
                #plt.close()
                i+=1
                if i%10==0:
                    break
        except tf.errors.OutOfRangeError:
            pass
        finally:
            coord.request_stop()
        coord.join(thread)
     
    
    '''
    输出:
    [[b'img\2.jpg']
     [b'img\3.jpg']
     [b'img\7.jpg']
     [b'img\5.jpg']
     [b'img\1.jpg']
     [b'img\4.jpg']
     [b'img\8.jpg']
     [b'img\6.jpg']]
    '''
     
     
    

    综合训练数据

    tensorflow提供了tf.train.batch和tf.train.shuffle函数将单个的样例组织成batch的形式输出。这两个函数都会产生一个队列,队列的入队操作是生成单个样例的方法,而每次出队得到的是一个batch的样例。他们唯一的区别在于是否会将数据顺序打乱。

    #%%
    import tensorflow as tf
    import numpy as np
    import os
     
    # 
    img_width = 208
    img_height = 208
     
     
    #%% 获取图片 及 生成标签
    train_dir = 'img/'
     
    def get_files(file_dir):
        '''
        args:
            file_dir: file directory
        Returns:
            ist of images and labels
        '''
        cats = []
        label_cats = []
        dogs = []
        label_dogs = []
        for file in os.listdir(file_dir): # 获取当前目录下的所有文件和目录名
            name = file.split('.') #分割字符段,返回name为一个列表
            if name[0] == 'cat':
                cats.append(file_dir + file)
                label_cats.append(0)
            else:
                dogs.append(file_dir + file)
                label_dogs.append(1)
        print('There are %d cats 
    There are %d dogs' %(len(cats), len(dogs)))
        
        image_list = np.hstack((cats, dogs)) ## 将图像堆叠在一起
        label_list = np.hstack((label_cats, label_dogs)) ## 将图像标签堆叠在一起
        
        temp = np.array([image_list, label_list]) # 将文件名和标签对应起来
        temp = temp.transpose() #矩阵转置
        np.random.shuffle(temp) # 打乱存放的顺序
        
        # 先集合起来打乱在分开的目的是为了获取打乱后的图形及其对应的标签
        image_list = list(temp[:, 0]) # 获取图片
        label_list = list(temp[:, 1]) # 获取标签
        label_list = [float(i) for i in label_list]
        
        return image_list, label_list
     
    #%%
    # 对图片进行裁剪
    def get_batch(image, label, image_W, image_H, batch_size, capacity):
        '''
        args:
            image: list type
            label: list type
            image_W: image_width
            image_H: image_Height
            batch_size:batch size #每批次的图像量
            capacity: the maxmum elements in queue
        Returns:
            image_batch: 4D tensor [batch_size, width, height, 3],dtype=tf.float32
            label_batch: 1D tensor [batch_size], dtype = tf.float32
        '''
        # 类型转换函数,返回张量
        image = tf.cast(image, tf.string) # 数据类型转换 image->string
        label = tf.cast(label, tf.int32)  # 数据类型转换 label->int32
        
        # make an input queue 生成输入对列
        input_queue = tf.train.slice_input_producer([image, label])
        
        label = input_queue[1] # 读取标签
        image_contents = tf.read_file(input_queue[0]) # 读取图像 string类型
        image = tf.image.decode_jpeg(image_contents, channels = 3) #解码
     
        ########################################
        # data argumentatioan should go to here
        ########################################
        # 对图片进行裁剪或扩充【在图像中心处裁剪】,统一大小
        image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
        # 数据标准化 训练前需要对数据进行标准化
        image = tf.image.per_image_standardization(image) 
        # 生成批次 在输入的tensor中创建一些tensor数据的batch
        image_batch, label_batch = tf.train.batch([image, label],
                                                  batch_size = batch_size,
                                                  num_threads = 64,
                                                  capacity = capacity) 
        # 重新生成大小,即将label_batch变换成[batch_size]行的形式
        label_batch = tf.reshape(label_batch, [batch_size]) 
        
        return image_batch, label_batch
        
    #%% test :  matplotlib.pyplot绘图 绘制直线、条形/矩形区域
     
    import matplotlib.pyplot as plt
     
    BATCH_SIZE = 5 # 批次中的图像数量
    CAPACITY = 256 # 队列中最多容纳元素的个数
    IMG_W = 208
    IMG_H = 208
     
    train_dir = 'data/train/'
     
    image_list, label_list = get_files(train_dir)
    image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H,
                                        BATCH_SIZE, CAPACITY)
     
    with tf.Session() as sess:
        print("start")
        i = 0
        # 开始输入队列监控,启动多线程处理数据
        coord = tf.train.Coordinator() # 
        threads = tf.train.start_queue_runners(coord = coord) # 启动入队线程
        
        try:
            while not coord.should_stop() and i<1:
                
                img, label = sess.run([image_batch, label_batch])# 输入list结构
                
                # just test one batch
                # arange返回一个array对象([ ])
                for j in np.arange(BATCH_SIZE):
                    print('label: %d'%label[j])
                    plt.imshow(img[j,:,:,:]) 
                    plt.show()
                i += 1
        except tf.errors.OutOfRangeError:
            print('done!')
        finally:
            print('finished')
            coord.request_stop() # 通知其它线程关闭
        coord.join(threads) # 其他线程关闭之后,这一函数才能返回
     
    '''
    输出:
    (略)
    '''
    

    参考文献:Tensorflow 实战Google深度学习框架 郑泽宇版

                     https://www.w3cschool.cn/tensorflow_python/tensorflow_python-caw628sg.html

    天上我才必有用,千金散尽还复来!
  • 相关阅读:
    List<Object> 查询解析优化
    hibernate 中 query.list()的优化
    移动端屏幕宽度自适应原理及实现
    js获取用户当前地理位置(省、市、经纬度)
    mescroll.js简单的上拉加载、下拉刷新插件,带完整注释
    Web前端性能优化总结——如何提高网页加载速度
    浏览器渲染页面的原理及流程
    优酷1080p的kux格式文件怎么转换为MP4格式?
    js处理文章详情页点击量统计
    plupload上传视频插件jQuery+php
  • 原文地址:https://www.cnblogs.com/lutaishi/p/13436325.html
Copyright © 2011-2022 走看看