zoukankan      html  css  js  c++  java
  • tfrecords转np.array

    import tensorflow as tf
    import numpy as np
    from keras.utils import to_categorical
    import sys
    
    
    def tfrecord2array(path_res):
        """Read every record of a TFRecord file into NumPy arrays.

        Each record is parsed with two features: ``image_raw`` (a byte string
        decoded as ``uint8``) and ``label`` (an ``int64`` scalar). Every image
        is binarized (non-zero -> 1) and flattened to 1-D before being stored.
        A one-line progress indicator is written to stdout while reading.

        Args:
            path_res: Path of the ``.tfrecords`` file to read. The name of its
                parent directory (split on ``/``) is shown in the progress
                line; it is left empty when the path has no directory part.

        Returns:
            A tuple ``(labels, images)``: ``labels`` is the one-hot encoding of
            the record labels with 68 classes, ``images`` is an array stacking
            the binarized, flattened ``uint8`` images. Note the order: labels
            first, images second.
        """
        imgs = []
        lbls = []
        # Spinner frames for the progress line; advances once per 100 samples.
        clock_lines = ['-', '\\', '|', '/']
        # Directory name shown in the progress line. Guarded so that a bare
        # file name (no '/') does not raise IndexError.
        path_parts = path_res.split('/')
        folder = path_parts[-2] if len(path_parts) >= 2 else ''

        reader = tf.TFRecordReader()
        # num_epochs=1 makes the queue raise OutOfRangeError after one pass.
        filename_queue = tf.train.string_input_producer([path_res], num_epochs=1)

        # Read one serialized example from the TFRecord file.
        _, serialized_example = reader.read(filename_queue)
        # Describe the layout of the serialized example.
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })

        # Decode the parsed features into tensors.
        labels = tf.cast(features['label'], tf.int64)
        images = tf.decode_raw(features['image_raw'], tf.uint8)

        with tf.Session() as sess:
            # The epoch counter of string_input_producer is a local variable.
            sess.run(tf.local_variables_initializer())
            # Start the queue-runner threads.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while not coord.should_stop():
                    try:
                        label, img = sess.run([labels, images])
                    except tf.errors.OutOfRangeError:
                        # The single epoch is exhausted: all records were read.
                        print("Turn to next folder.")
                        break
                    # Binarize (non-zero -> 1) and flatten the image.
                    img = (img > 0).astype(np.uint8).reshape(-1)
                    imgs.append(img)
                    lbls.append(label)

                    # len() is O(1); rebuilding an array just to count the
                    # samples made the loop quadratic.
                    count = len(lbls)
                    # '\r' returns the cursor to the line start so the
                    # progress line is overwritten in place.
                    sys.stdout.write(
                        ''.join((str(count),
                                 "-th sample in ",
                                 folder,
                                 clock_lines[count // 100 % 4],
                                 '\r')))
                    sys.stdout.flush()
            finally:
                # Always stop and join the reader threads, even on error.
                coord.request_stop()
                coord.join(threads)
        return to_categorical(np.array(lbls), num_classes=68), np.array(imgs)
    
    
    def main():
        """Convert the sample test TFRecord file and print the array shapes."""
        # tfrecord2array returns (labels, images) in that order; unpacking
        # them the other way round mislabelled the printed shapes.
        labels, imgs = tfrecord2array(
            r"./data_tfrecords/integers_tfrecords/test.tfrecords")
        print("imgs.shape:", imgs.shape)
        print("labels.shape:", labels.shape)
    
    
    # Run the conversion demo only when executed as a script, not on import.
    if __name__ == '__main__':
        main()
    
  • 相关阅读:
    Elasticsearch常用命令
    Linux中使用systemctl操作服务、新建自定义服务
    Windows下安装MongoDB解压版
    Java执行cmd命令、bat脚本、linux命令,shell脚本等
    Ubuntu
    PostgreSQL删除数据库失败处理
    Ubuntu service 命令
    Ubuntu18修改/迁移mysql5.7数据存放路径
    攻防世界-web-ics-02(sql注入、ssrf、目录扫描)
    攻防世界-web-filemanager(源码泄漏、二次注入)
  • 原文地址:https://www.cnblogs.com/ZhengPeng7/p/7942362.html
Copyright © 2011-2022 走看看