zoukankan      html  css  js  c++  java
  • tensorflow文件读取

    1、知识点

    """
    注意:在tensorflow当中,运行操作具有依赖性
    
    1、CPU操作计算与IO计算区别:
            CPU操作:
                1、tensorflow是一个正真的多线程,并行的执行任务
                2、使用tfrecords对文件读取进行改善
                
            IO操作:
                1、一次性读取数据,消耗内存
                2、一次性进行训练
            
    2、队列API:        
            1、tf.FIFOQueue(capacity, dtypes, name='fifo_queue') 先进先出队列,按顺序出队列
                    capacity:整数。可能存储在此队列中的元素数量的上限
                    dtypes:DType对象列表。长度dtypes必须等于每个队列元素中的张量数,dtype的类型形状,决定了后面进队列元素形状
                    return:返回一个进队列操作
                            dequeue(name=None) #从队列获取一个数据
                            enqueue(vals, name=None) #将数据存放在队列
                            enqueue_many(vals, name=None):放入数据,其中vals列表或者元组
        
            2、tf.RandomShuffleQueue 随机出队列
            
    3、队列管理器:qr = tf.train.QueueRunner(Q,enqueue_ops=[en_q*2])
                qr.create_threads(sess,start=True)    #开启子线程    
    
    4、线程协调器:tf.train.Coordinator() ,线程协调员,实现一个简单的机制来协调一组线程的终止
                返回对象方法:
                    request_stop() 
                    should_stop() 检查是否要求停止
                    join(threads=None, stop_grace_period_secs=120)  等待线程终止
                    return:线程协调员实例
    
    5、CSV文件读取步骤:
            1、先找到文件,构造一个列表 
                file_name = os.listdir("./csvData/")
                file_list = [os.path.join(file) for file in file_name]
            2、构造文件列队
                
            3、构造阅读器,读取队列内容(一行)
            4、解码内容
            5、批处理(多个样本)
    
    6、文件读取API-文件队列构造:tf.train.string_input_producer(string_tensor,,shuffle=True) 将输出字符串(例如文件名)输入到管道队列
             参数:   
                string_tensor    含有文件名的1阶张量
                num_epochs:过几遍数据,默认无限过数据
                return:具有输出字符串的队列
                
    7、文件读取API-文件阅读器:根据文件格式,选择对应的文件阅读器
        a) class tf.TextLineReader() 阅读文本文件逗号分隔值(CSV)格式,默认按行读取
                return:读取器实例
        b) tf.FixedLengthRecordReader(record_bytes)要读取每个记录是固定数量字节的二进制文件
                record_bytes:整型,指定每次读取的字节数
                return:读取器实例
        c) tf.TFRecordReader    读取TfRecords文件
        共同的读取方法:read(file_queue):从队列中指定数量内容 ,返回一个Tensors元组(key文件名字,value默认的内容(行,字节))
    
    8、文件读取API-文件内容解码器:由于从文件中读取的是字符串,需要函数去解析这些字符串到张量
        a) tf.decode_csv(records,record_defaults=None,field_delim = None,name = None)  将CSV转换为张量,与tf.TextLineReader搭配使用
            records:tensor型字符串,每个字符串是csv中的记录行
            field_delim:默认分割符”,”
            record_defaults:参数决定了所得张量的类型,并设置一个值在输入字符串中缺少使用默认值,如
        b) tf.decode_raw(bytes,out_type,little_endian = None,name = None) 
            将字节转换为一个数字向量表示,字节为一字符串类型的张量,与函数tf.FixedLengthRecordReader搭配使用,二进制读取为uint8格式
    
    9、开启线程操作
        tf.train.start_queue_runners(sess=None,coord=None) 收集所有图中的队列线程,并启动线程
                sess:所在的会话中
                coord:线程协调器
                return:返回所有线程队列
    
    9、管道读端批处理:
        a) tf.train.batch(tensors,batch_size,num_threads = 1,capacity = 32,name=None) 读取指定大小(个数)的张量
                tensors:可以是包含张量的列表
                batch_size:从队列中读取的批处理大小
                num_threads:进入队列的线程数
                capacity:整数,队列中元素的最大数量
                return:tensors
        b) tf.train.shuffle_batch(tensors,batch_size,capacity,min_after_dequeue,num_threads=1,) 乱序读取指定大小(个数)的张量
                min_after_dequeue:留下队列里的张量个数,能够保持随机打乱
    
    10、错误
        OutOfRangeError (see above for traceback): FIFOQueue '_1_batch/fifo_queue' is closed and has insufficient elements (requested 9, current size 0)
        解决方法:由于从上可知需要9个数据,但是读取为0,因此可能是数据有问题,即数据文件或者读取路径有问题
    """

    2、代码

    # coding = utf-8
    
    import tensorflow as tf
    import  os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    def readCSVFile(filelist):
        """
        读取CSV文件
        :param filelist:  文件路径+名字列表
        :return: 读取的内容
        """
        #1、构造文件队列
        file_queue = tf.train.string_input_producer(file_list)
        #2、构造CSV阅读器读取队列数据(读一行)
        reader = tf.TextLineReader()
        key , value = reader.read(file_queue)
    
        #3、对每行内容解码
        #record_defaults:指定每一个样本的每一列的类型,指定默认值[["None"], [4.0]]
        records =[["None"],["None"]]
        example , label = tf.decode_csv(value,record_defaults=records)
        #批处理大小跟队列、数据的数量没有影响,只决定这批次取多少数据batch_size
    
        ##############批处理####################
       #读取多个数据,就需要使用批处理
        example_batch,label_batch = tf.train.batch([example,label],batch_size=20,num_threads=1,capacity=90)
        print(example_batch, label_batch)
        return example_batch,label_batch
        #return example,label
    
    
    ############队列################
    def queue():
        #1、首先定义数据
        Q = tf.FIFOQueue(3,tf.float32)
    
        #2、放入数据
        enq_many = Q.enqueue_many([[0.1,0.2,0.3],])
    
        #定义一些数据处理的逻辑
        out_q = Q.dequeue()
        out_q = out_q + 1
        en_q = Q.enqueue(out_q)
    
        #运行会话
        with tf.Session() as sess:
            #初始化队列
            sess.run(enq_many)
            #处理数据
            for i in range(100):
                sess.run(en_q)
            for i in range(Q.size().eval()):
                print(sess.run(Q.dequeue()))
        return None
    
    
    #############异步执行#########################
    def unasynQueue():
        """
        异步读取
        :return:
        """
        #1、定义一个队列,1000
        Q = tf.FIFOQueue(1000,tf.float32)
        #2、定义要做的事,并放入队列中
        var = tf.Variable(0.0)
        #实现自增
        data = tf.assign_add(var,tf.constant(1.0))
        en_q = Q.enqueue(data)
        #3、定义队列管理器,指定多少个子线程,子线程做事
        qr = tf.train.QueueRunner(Q,enqueue_ops=[en_q,]*2)
    
        #初始化变量OP
        init_op = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init_op)
    
            # 开启线程管理器
            coord = tf.train.Coordinator()
            #开启子线程
            threads = qr.create_threads(sess,coord,start=True)
            #主线程不断读取数据
            for i in range(300):
                print(sess.run(Q.dequeue()))
    
            #回收线程
            coord.request_stop()
            coord.join(threads)
        return  None
    
    if __name__ == '__main__':
        file_name = os.listdir("./csvData/")
        file_list = [os.path.join("./csvData/",file) for file in file_name]
        example_batch ,label_batch = readCSVFile(file_list)
        with tf.Session() as sess:
            # #定义一个线程协调器
            coord = tf.train.Coordinator()
            # #开启读文件的线程
            threads = tf.train.start_queue_runners(sess,coord=coord)
            #打印读取的内容
            print(sess.run([example_batch,label_batch]))
            #回收线程
            coord.request_stop()
            coord.join(threads)
  • 相关阅读:
    日常记Bug
    Docker部署Django
    杂记:防火墙、企业微信登陆、RestFrameWork
    Python2和Python3的编码
    杂记:Django和static,Nginx配置路径,json_schema
    xlwt模块的使用
    企业微信登陆
    markdown八条基础语法
    webstorm 添加markdown支持
    【electron系列之二】复制图片
  • 原文地址:https://www.cnblogs.com/ywjfx/p/10919211.html
Copyright © 2011-2022 走看看