zoukankan      html  css  js  c++  java
  • 采用tfrecord形式读写训练数据

    tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。尤其在面对海量数据时,使用常用的内存读取方式变得不切实际,tfrecored方式为我们带来了更大的便捷,同时还可以配合shuffe大大提高model的train效率。

    示例代def convert_tfrecord(data, label):

    """保存为tfrecord形式
        :param data:
        :param label:
        :return:
        """
        record_path = './resources/train.tfrecord'
        # 调用example和features函数将数据格式化保存起来
        cnt = 0
        writer = tf.python_io.TFRecordWriter(record_path)
        for d, s, l in zip(data[0], data[1], label):
            if cnt % 100 == 0:
                print('write example {}'.format(cnt))
            cnt += 1
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'sample': tf.train.Feature(int64_list=tf.train.Int64List(value=d)),
                        'score': tf.train.Feature(float_list=tf.train.FloatList(value=s)),
                        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[l]))
                    }
                )
            )
    
            writer.write(example.SerializeToString())
        writer.close()
        print('写入ok')
    
        # 读取,batch 取
        filename_queue = tf.train.string_input_producer([record_path],)
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
    features
    = tf.io.parse_single_example(serialized_example, features={ 'sample': tf.io.FixedLenFeature([9], tf.int64), 'score': tf.io.FixedLenFeature([9], tf.float32), 'label': tf.io.FixedLenFeature([1], tf.int64), }) is_batch = True if is_batch: batch_size = 3 min_after_dequeue = 10 capacity = min_after_dequeue + 3 * batch_size samples, scores, labels = tf.train.shuffle_batch([features['sample'], features['score'], features['label']], batch_size=batch_size, num_threads=3, capacity=capacity, min_after_dequeue=min_after_dequeue) with tf.compat.v1.Session() as sess: init_op = tf.initialize_all_variables() sess.run(init_op) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(1000): # 从会话中取出数据 sample, score, label = sess.run([samples, scores, labels]) print(sample) print(score) print('###########') coord.request_stop() coord.join(threads) print('ok')
  • 相关阅读:
    EventBus--介绍
    EventBus--出现的问题
    File存对象--android 的File存储到SD卡();
    SharePrecences--(json+sharePrecences)存list 或对象
    缓存AsimpleCache -- 解决Android中Sharedpreferences无法存储List数据/ASimpleCache
    ViewPager--左右可滑动的
    git之win安装git和环境配置及常用命令总结
    mySql事务_ _Java中怎样实现批量删除操作(Java对数据库进行事务处理)?
    eclispe---快捷键设置
    bug_ _org.json.JSONException: End of input at character 0 of
  • 原文地址:https://www.cnblogs.com/demo-deng/p/13789061.html
Copyright © 2011-2022 走看看