  • Reading and writing training data in TFRecord format

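    A TFRecord file is a binary file that stores samples and their labels (for example, image data together with labels) as a sequence of serialized tf.train.Example records. Because the records are read sequentially from disk instead of being loaded all at once, the format uses memory more efficiently and is fast to copy, move, read and write within TensorFlow. With very large datasets, the usual approach of reading everything into memory becomes impractical; TFRecord makes this much more convenient, and combined with shuffling it can greatly improve the efficiency of model training.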
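    "One binary file" becomes concrete once you see how records are framed. The sketch below is a pure-Python illustration (no TensorFlow needed) of the on-disk layout TensorFlow documents for TFRecord: each record is a little-endian uint64 length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. In a real file the payload is a serialized tf.train.Example; here it is arbitrary bytes. The helper names (crc32c, masked_crc, encode_record, write_record, read_records) are hypothetical, not TensorFlow APIs. The reader pulls one record at a time from the stream, which is why the whole dataset never has to sit in memory.

```python
import io
import struct
from typing import BinaryIO, Iterator


def crc32c(data: bytes) -> int:
    """CRC-32C (Castagnoli), bitwise with the reflected polynomial 0x82F63B78."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord stores a 'masked' CRC: rotate right by 15 bits, add a constant."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def encode_record(payload: bytes) -> bytes:
    """length (uint64) | masked_crc(length) | payload | masked_crc(payload)."""
    header = struct.pack("<Q", len(payload))
    return (header
            + struct.pack("<I", masked_crc(header))
            + payload
            + struct.pack("<I", masked_crc(payload)))


def write_record(stream: BinaryIO, payload: bytes) -> None:
    stream.write(encode_record(payload))


def read_records(stream: BinaryIO) -> Iterator[bytes]:
    """Yield payloads one by one; only the current record is held in memory."""
    while True:
        header = stream.read(8)
        if not header:
            return  # clean end of file
        if len(header) != 8:
            raise ValueError("truncated length header")
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupt length header")
        payload = stream.read(length)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupt payload")
        yield payload


# Round trip through an in-memory "file"
buf = io.BytesIO()
for p in (b"first record", b"", b"third"):
    write_record(buf, p)
buf.seek(0)
print(list(read_records(buf)))  # [b'first record', b'', b'third']
```

    The CRCs let a reader detect a corrupted or truncated file instead of silently training on garbage; the bitwise CRC loop is slow and is only meant to make the format readable.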

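    Example code:

    import os

    import tensorflow as tf

    # Under TF 2.x the queue-based reader below only works in graph mode,
    # so switch eager execution off before building any ops.
    tf.compat.v1.disable_eager_execution()


    def convert_tfrecord(data, label):
        """Save the data as a TFRecord file, then read it back in shuffled batches.

        :param data: a pair (samples, scores); each sample is a list of 9 ints and
                     each score a list of 9 floats (matching the read-side shapes)
        :param label: one int label per sample
        :return: None
        """
        record_path = './resources/train.tfrecord'
        os.makedirs(os.path.dirname(record_path), exist_ok=True)
        # Wrap each record in tf.train.Example / tf.train.Features and serialize it
        cnt = 0
        writer = tf.io.TFRecordWriter(record_path)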
        for d, s, l in zip(data[0], data[1], label):
            if cnt % 100 == 0:
                print('write example {}'.format(cnt))
            cnt += 1
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=d)),
                        'score': tf.train.Feature(float_list=tf.train.FloatList(value=s)),
                        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[l]))
                    }
                )
            )
    
            writer.write(example.SerializeToString())
        writer.close()
        print('write ok')

        # Read the file back and fetch shuffled batches (TF1 queue-based pipeline)
        filename_queue = tf.compat.v1.train.string_input_producer([record_path])
        reader = tf.compat.v1.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.io.parse_single_example(
            serialized_example,
            features={
                'sample': tf.io.FixedLenFeature([9], tf.int64),
                'score': tf.io.FixedLenFeature([9], tf.float32),
                'label': tf.io.FixedLenFeature([1], tf.int64),
            })

        is_batch = True
        if is_batch:
            batch_size = 3
            min_after_dequeue = 10
            capacity = min_after_dequeue + 3 * batch_size
            samples, scores, labels = tf.compat.v1.train.shuffle_batch(
                [features['sample'], features['score'], features['label']],
                batch_size=batch_size,
                num_threads=3,
                capacity=capacity,
                min_after_dequeue=min_after_dequeue)

            with tf.compat.v1.Session() as sess:
                # tf.initialize_all_variables() is deprecated; use the v1 initializers
                sess.run([tf.compat.v1.global_variables_initializer(),
                          tf.compat.v1.local_variables_initializer()])
                coord = tf.compat.v1.train.Coordinator()
                threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
                for i in range(1000):
                    # Pull one shuffled batch out of the session
                    sample, score, label = sess.run([samples, scores, labels])
                    print(sample)
                    print(score)
                    print('###########')
                coord.request_stop()
                coord.join(threads)
        print('ok')
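    To see what batch_size and min_after_dequeue control in tf.train.shuffle_batch, here is a simplified single-threaded model in plain Python. It is not TensorFlow's implementation: the real queue is filled by num_threads background threads and bounded by capacity, which caps how much is prefetched. The function shuffle_batches is hypothetical. What it does show is the core rule: a batch is only dequeued while at least min_after_dequeue elements stay buffered, so every element is drawn at random from a pool of at least that size. A larger min_after_dequeue mixes better but costs memory and a longer start-up fill.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def shuffle_batches(items: Iterable[T], batch_size: int, min_after_dequeue: int,
                    seed: Optional[int] = None) -> Iterator[List[T]]:
    """Simplified, single-threaded model of a random-shuffle batching queue."""
    rng = random.Random(seed)
    buffer: List[T] = []

    def take_batch() -> List[T]:
        batch = []
        for _ in range(batch_size):
            # Swap a random element to the end and pop it (O(1) removal)
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            batch.append(buffer.pop())
        return batch

    for item in items:
        buffer.append(item)
        # Dequeue only if min_after_dequeue elements remain behind afterwards
        if len(buffer) >= min_after_dequeue + batch_size:
            yield take_batch()

    # Input exhausted: drain in full batches and drop a short remainder,
    # like shuffle_batch's default allow_smaller_final_batch=False.
    while len(buffer) >= batch_size:
        yield take_batch()


for b in shuffle_batches(range(20), batch_size=3, min_after_dequeue=10, seed=0):
    print(b)
```

    With the values used in the code above (batch_size=3, min_after_dequeue=10), the first batch can only contain records from the first 13 read. In that code the input never actually ends, because string_input_producer cycles through the file indefinitely when num_epochs is not set.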
  • Original article: https://www.cnblogs.com/demo-deng/p/13789061.html