import tensorflow as tf import os ''' tensorflow读取文件的流程,每一步每一种数据都有对应封装好的API进行处理: 1、构造一个文件队列:A文件,B文件,C文件,每个文件内有100个样本 2、读取队列内容:一个一个样本读取,二进制文件就是指定一个样本的byte读取,图片就是一张一张 3、进行解码 4、批处理:将样本一个一个的放入一个队列中,达到一定数量后,一次性进行处理 ''' def readcsv(filelist): """读取csv文件""" # 1.构造文件队列 file_queue = tf.train.string_input_producer(filelist) # 2.构造CSV阅读器读取队列数据(按一行),read返回一个元组,一个是路径一个是样本内容 reader = tf.TextLineReader() key, value = reader.read(file_queue) # 对每行内容进行解码,field_delim分隔符默认“,”,record_defaults指定每一个样本的每一列类型,并设置默认值对缺失值进行填充 # CSV数据中有几列就应该有几个列表,"None"表示字符串并同时指定默认值是None,1表示int类型,并同时指定默认值是1,如果是4.5则是float类型,默认值就是4.5 # decode_csv返回的是每一个样本每一个的值,返回的是op列表 records = [["None"],[1]] example, label = tf.decode_csv(value, field_delim=',', record_defaults=records) # 批处理,batch_size从队列读取的批处理大小,num_threads使用几个线程处理,capacity批处理队列大小,tf.train.batch返回的是两个元素的op,一个op存储着一列九行数据 example_batch, label_batch = tf.train.batch([example, label], batch_size=5, num_threads=1, capacity=10) return example_batch, label_batch if __name__ == "__main__": # 构造文件列表 file_name = os.listdir("./data/csvdata") filelist = [os.path.join("./data/csvdata", file) for file in file_name] example_batch, label_batch = readcsv(filelist) # 开启会话 with tf.Session() as sess: # 定义线程协调器 coord = tf.train.Coordinator() # 开启读取文件的线程 thd = tf.train.start_queue_runners(sess, coord=coord, start=True) # 打印读取内容 print(sess.run([example_batch, label_batch])) # 回收子线程 coord.request_stop() coord.join(thd)